diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py
index 4c9618a96739c612a088d275f594e258a8ff223d..65b8b43163bfb5c4794d70bd498d631cc99e6df8 100644
--- a/predtuner/torchapp.py
+++ b/predtuner/torchapp.py
@@ -229,7 +229,7 @@ class TorchApp(ModeledApp, abc.ABC):
         return module_indexer.module
 
     def _sample_input(self):
-        inputs, _ = next(iter(self.tune_loader))
+        inputs, _ = next(iter(DataLoader(self.tune_loader.dataset, batch_size=1)))
         return inputs.to(self.device)