From 5492d73ec80d9bc3f98ff9674dfbc5652f218d70 Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Mon, 25 Jan 2021 15:28:28 -0600 Subject: [PATCH] Updated test cases --- predtuner/modeledapp.py | 6 +++--- predtuner/torchapp.py | 2 +- test/integrated_tuning.py | 9 +++++---- test/test_torchapp.py | 42 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/predtuner/modeledapp.py b/predtuner/modeledapp.py index 6bd15bb..2f02418 100644 --- a/predtuner/modeledapp.py +++ b/predtuner/modeledapp.py @@ -92,7 +92,7 @@ class ModeledApp(ApproxApp, abc.ABC): def get_tuner(self) -> "ApproxModeledTuner": return ApproxModeledTuner(self) - def _init_model(self, model_name: str): + def init_model(self, model_name: str): self._name_to_model[model_name]._init() @@ -371,10 +371,10 @@ class ApproxModeledTuner(ApproxTuner): msg_logger.info("Starting tuning with %s and %s", qos_desc, perf_desc) if qos_model != "none": msg_logger.info("Initializing qos model %s", qos_model) - self.app._init_model(qos_model) + self.app.init_model(qos_model) if perf_model != "none": msg_logger.info("Initializing performance model %s", perf_model) - self.app._init_model(perf_model) + self.app.init_model(perf_model) ret = super().tune( max_iter=max_iter, qos_tuner_threshold=qos_tuner_threshold, diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py index 2e781ac..6cad524 100644 --- a/predtuner/torchapp.py +++ b/predtuner/torchapp.py @@ -70,7 +70,7 @@ class TorchApp(ModeledApp, abc.ABC): self.tensor_to_qos = tensor_to_qos self.combine_qos = combine_qos self.device = device - self.model_storage = Path(model_storage_folder) + self.model_storage = Path(model_storage_folder) if model_storage_folder else None self.module = self.module.to(device) self.midx = ModuleIndexer(module) diff --git a/test/integrated_tuning.py b/test/integrated_tuning.py index bfe1fb0..93c3704 100644 --- a/test/integrated_tuning.py +++ b/test/integrated_tuning.py @@ -14,14 +14,15 @@ msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True) dataset = CIFAR.from_file( "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin" ) -tune_loader = DataLoader(Subset(dataset, range(5000)), batch_size=500) -calib_loader = DataLoader(Subset(dataset, range(5000, 10000)), batch_size=500) +tune_loader = DataLoader(Subset(dataset, range(500)), batch_size=500) +calib_loader = DataLoader(Subset(dataset, range(5000, 5500)), batch_size=500) module = VGG16Cifar10() module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar")) app = TorchApp( "TestTorchApp", module, tune_loader, calib_loader, get_knobs_from_file(), accuracy, + model_storage_folder="tuner_results/vgg16_cifar10" ) baseline, _ = app.measure_qos_perf({}, False) tuner = app.get_tuner() -tuner.tune(500, 2.1, 3.0, True, 50) -tuner.write_configs_to_dir("tuner_results/test") +tuner.tune(100, 2.1, 3.0, True, 50, perf_model="perf_linear", qos_model="qos_p1") +tuner.dump_configs("tuner_results/test/configs.json") diff --git a/test/test_torchapp.py b/test/test_torchapp.py index 12c9121..2b4650b 100644 --- a/test/test_torchapp.py +++ b/test/test_torchapp.py @@ -113,3 +113,45 @@ class TestModeledTuning(TorchAppSetUp): perf_model="perf_linear", qos_model="qos_p2", ) + + +class TestModelSaving(TorchAppSetUp): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.baseline, _ = cls.app.measure_qos_perf({}, False) + cls.model_path = "/tmp/test_models" + app = cls.get_app() + app.init_model("qos_p1") + app.init_model("qos_p2") + cls.app = cls.get_app() + + @classmethod + def get_app(cls): + return TorchApp( + "TestTorchApp", + cls.module, + DataLoader(cls.dataset, batch_size=500), + DataLoader(cls.dataset, batch_size=500), + get_knobs_from_file(), + accuracy, + model_storage_folder=cls.model_path + ) + + def test_loading_p1(self): + self.app.get_tuner().tune( + 100, + 3.0, + is_threshold_relative=True, + perf_model="perf_linear", + qos_model="qos_p1", + ) + + def test_loading_p2(self): + self.app.get_tuner().tune( + 100, + 3.0, + is_threshold_relative=True, + perf_model="perf_linear", + qos_model="qos_p2", + ) -- GitLab