From 0d518aba992e36f7317b6180ec6bb20181a5041d Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Thu, 28 Jan 2021 20:36:38 -0600
Subject: [PATCH] Added progress bar support

---
 predtuner/torchapp.py      | 15 ++++++++++++---
 test/test_model_zoo_acc.py |  2 +-
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py
index 0e6fc28..4c9618a 100644
--- a/predtuner/torchapp.py
+++ b/predtuner/torchapp.py
@@ -9,8 +9,14 @@ from torch.utils.data.dataloader import DataLoader
 
 from ._logging import PathLike
 from .approxapp import ApproxKnob, BaselineKnob, KnobsT
-from .modeledapp import (IPerfModel, IQoSModel, LinearPerfModel, ModeledApp,
-                         QoSModelP1, QoSModelP2)
+from .modeledapp import (
+    IPerfModel,
+    IQoSModel,
+    LinearPerfModel,
+    ModeledApp,
+    QoSModelP1,
+    QoSModelP2,
+)
 from .torchutil import ModuleIndexer, get_summary, move_to_device_recursively
 
 
@@ -161,14 +167,17 @@ class TorchApp(ModeledApp, abc.ABC):
 
     @torch.no_grad()
     def empirical_measure_qos_perf(
-        self, with_approxes: KnobsT, is_test: bool
+        self, with_approxes: KnobsT, is_test: bool, progress: bool = False
     ) -> Tuple[float, float]:
         """Measure the QoS and performance of Module with given approximation
         empirically (i.e., by running the Module on the dataset)."""
 
         from time import time_ns
+        from tqdm import tqdm
 
         dataloader = self.test_loader if is_test else self.tune_loader
+        if progress:
+            dataloader = tqdm(dataloader)
         approxed = self._apply_knobs(with_approxes)
         qoses = []
 
diff --git a/test/test_model_zoo_acc.py b/test/test_model_zoo_acc.py
index 55f770d..e402881 100644
--- a/test/test_model_zoo_acc.py
+++ b/test/test_model_zoo_acc.py
@@ -34,5 +34,5 @@ class TestModelZooAcc(unittest.TestCase):
             )
             tune = DataLoader(dataset, batchsize)
             app = TorchApp("", network, tune, tune, get_knobs_from_file(), accuracy)
-            qos, _ = app.measure_qos_perf({}, False)
+            qos, _ = app.empirical_measure_qos_perf({}, False, True)
             self.assertAlmostEqual(qos, target_acc)
-- 
GitLab