diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py index 4c9618a96739c612a088d275f594e258a8ff223d..65b8b43163bfb5c4794d70bd498d631cc99e6df8 100644 --- a/predtuner/torchapp.py +++ b/predtuner/torchapp.py @@ -229,7 +229,7 @@ class TorchApp(ModeledApp, abc.ABC): return module_indexer.module def _sample_input(self): - inputs, _ = next(iter(self.tune_loader)) + inputs, _ = next(iter(DataLoader(self.tune_loader.dataset, batch_size=1))) return inputs.to(self.device)