From 5492d73ec80d9bc3f98ff9674dfbc5652f218d70 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Mon, 25 Jan 2021 15:28:28 -0600
Subject: [PATCH] Updated test cases

---
 predtuner/modeledapp.py   |  6 +++---
 predtuner/torchapp.py     |  2 +-
 test/integrated_tuning.py |  9 +++++----
 test/test_torchapp.py     | 42 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/predtuner/modeledapp.py b/predtuner/modeledapp.py
index 6bd15bb..2f02418 100644
--- a/predtuner/modeledapp.py
+++ b/predtuner/modeledapp.py
@@ -92,7 +92,7 @@ class ModeledApp(ApproxApp, abc.ABC):
     def get_tuner(self) -> "ApproxModeledTuner":
         return ApproxModeledTuner(self)
 
-    def _init_model(self, model_name: str):
+    def init_model(self, model_name: str):
         self._name_to_model[model_name]._init()
 
 
@@ -371,10 +371,10 @@ class ApproxModeledTuner(ApproxTuner):
         msg_logger.info("Starting tuning with %s and %s", qos_desc, perf_desc)
         if qos_model != "none":
             msg_logger.info("Initializing qos model %s", qos_model)
-            self.app._init_model(qos_model)
+            self.app.init_model(qos_model)
         if perf_model != "none":
             msg_logger.info("Initializing performance model %s", perf_model)
-            self.app._init_model(perf_model)
+            self.app.init_model(perf_model)
         ret = super().tune(
             max_iter=max_iter,
             qos_tuner_threshold=qos_tuner_threshold,
diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py
index 2e781ac..6cad524 100644
--- a/predtuner/torchapp.py
+++ b/predtuner/torchapp.py
@@ -70,7 +70,7 @@ class TorchApp(ModeledApp, abc.ABC):
         self.tensor_to_qos = tensor_to_qos
         self.combine_qos = combine_qos
         self.device = device
-        self.model_storage = Path(model_storage_folder)
+        self.model_storage = Path(model_storage_folder) if model_storage_folder else None
 
         self.module = self.module.to(device)
         self.midx = ModuleIndexer(module)
diff --git a/test/integrated_tuning.py b/test/integrated_tuning.py
index bfe1fb0..93c3704 100644
--- a/test/integrated_tuning.py
+++ b/test/integrated_tuning.py
@@ -14,14 +14,15 @@ msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True)
 dataset = CIFAR.from_file(
     "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
 )
-tune_loader = DataLoader(Subset(dataset, range(5000)), batch_size=500)
-calib_loader = DataLoader(Subset(dataset, range(5000, 10000)), batch_size=500)
+tune_loader = DataLoader(Subset(dataset, range(500)), batch_size=500)
+calib_loader = DataLoader(Subset(dataset, range(5000, 5500)), batch_size=500)
 module = VGG16Cifar10()
 module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
 app = TorchApp(
     "TestTorchApp", module, tune_loader, calib_loader, get_knobs_from_file(), accuracy,
+    model_storage_folder="tuner_results/vgg16_cifar10"
 )
 baseline, _ = app.measure_qos_perf({}, False)
 tuner = app.get_tuner()
-tuner.tune(500, 2.1, 3.0, True, 50)
-tuner.write_configs_to_dir("tuner_results/test")
+tuner.tune(100, 2.1, 3.0, True, 50, perf_model="perf_linear", qos_model="qos_p1")
+tuner.dump_configs("tuner_results/test/configs.json")
diff --git a/test/test_torchapp.py b/test/test_torchapp.py
index 12c9121..2b4650b 100644
--- a/test/test_torchapp.py
+++ b/test/test_torchapp.py
@@ -113,3 +113,45 @@ class TestModeledTuning(TorchAppSetUp):
             perf_model="perf_linear",
             qos_model="qos_p2",
         )
+
+
+class TestModelSaving(TorchAppSetUp):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.baseline, _ = cls.app.measure_qos_perf({}, False)
+        cls.model_path = "/tmp/test_models"
+        app = cls.get_app()
+        app.init_model("qos_p1")
+        app.init_model("qos_p2")
+        cls.app = cls.get_app()
+
+    @classmethod
+    def get_app(cls):
+        return TorchApp(
+            "TestTorchApp",
+            cls.module,
+            DataLoader(cls.dataset, batch_size=500),
+            DataLoader(cls.dataset, batch_size=500),
+            get_knobs_from_file(),
+            accuracy,
+            model_storage_folder=cls.model_path
+        )
+
+    def test_loading_p1(self):
+        self.app.get_tuner().tune(
+            100,
+            3.0,
+            is_threshold_relative=True,
+            perf_model="perf_linear",
+            qos_model="qos_p1",
+        )
+
+    def test_loading_p2(self):
+        self.app.get_tuner().tune(
+            100,
+            3.0,
+            is_threshold_relative=True,
+            perf_model="perf_linear",
+            qos_model="qos_p2",
+        )
-- 
GitLab