diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py index bf1da0270b7b98dc3405bb3fd3122809d8c21cc0..59925541bb767280f8e70618822e9774d1006b18 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py +++ b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py @@ -139,11 +139,14 @@ class ModelExporter: import numpy as np from torch.utils.data import DataLoader + def link_from_to(from_: PathLike, to: PathLike): + from_, to = Path(from_), Path(to) + from_.unlink(missing_ok=True) + from_.symlink_to(to.absolute()) + if isinstance(dataset, BinDataset): - input_filename.unlink(missing_ok=True) - labels_filename.unlink(missing_ok=True) - Path(input_filename).symlink_to(dataset.input_file) - Path(labels_filename).symlink_to(dataset.labels_file) + link_from_to(input_filename, dataset.input_file) + link_from_to(labels_filename, dataset.labels_file) return inputs, labels = zip(*iter(DataLoader(dataset))) inputs = np.stack(inputs, axis=0)