diff --git a/predtuner/modeledapp.py b/predtuner/modeledapp.py index 6bd15bb23cd5775bc462a9991239134c0a58d76e..2f02418da78a3505d1963e8f5cba68d4327e0278 100644 --- a/predtuner/modeledapp.py +++ b/predtuner/modeledapp.py @@ -92,7 +92,7 @@ class ModeledApp(ApproxApp, abc.ABC): def get_tuner(self) -> "ApproxModeledTuner": return ApproxModeledTuner(self) - def _init_model(self, model_name: str): + def init_model(self, model_name: str): self._name_to_model[model_name]._init() @@ -371,10 +371,10 @@ class ApproxModeledTuner(ApproxTuner): msg_logger.info("Starting tuning with %s and %s", qos_desc, perf_desc) if qos_model != "none": msg_logger.info("Initializing qos model %s", qos_model) - self.app._init_model(qos_model) + self.app.init_model(qos_model) if perf_model != "none": msg_logger.info("Initializing performance model %s", perf_model) - self.app._init_model(perf_model) + self.app.init_model(perf_model) ret = super().tune( max_iter=max_iter, qos_tuner_threshold=qos_tuner_threshold, diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py index 2e781acc4c7aef325a01e7688ddfd0b4a60acb69..6cad524cc8cb0293867d70820e71f0c8acccb861 100644 --- a/predtuner/torchapp.py +++ b/predtuner/torchapp.py @@ -70,7 +70,7 @@ class TorchApp(ModeledApp, abc.ABC): self.tensor_to_qos = tensor_to_qos self.combine_qos = combine_qos self.device = device - self.model_storage = Path(model_storage_folder) + self.model_storage = Path(model_storage_folder) if model_storage_folder else None self.module = self.module.to(device) self.midx = ModuleIndexer(module) diff --git a/test/integrated_tuning.py b/test/integrated_tuning.py index bfe1fb08c6d280d8181258957ccbaa6c79ae9a1d..93c370449bac50ed1a438e801851dab1b890c6d4 100644 --- a/test/integrated_tuning.py +++ b/test/integrated_tuning.py @@ -14,14 +14,15 @@ msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True) dataset = CIFAR.from_file( "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin" ) -tune_loader = DataLoader(Subset(dataset, range(5000)), batch_size=500) -calib_loader = DataLoader(Subset(dataset, range(5000, 10000)), batch_size=500) +tune_loader = DataLoader(Subset(dataset, range(500)), batch_size=500) +calib_loader = DataLoader(Subset(dataset, range(5000, 5500)), batch_size=500) module = VGG16Cifar10() module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar")) app = TorchApp( "TestTorchApp", module, tune_loader, calib_loader, get_knobs_from_file(), accuracy, + model_storage_folder="tuner_results/vgg16_cifar10" ) baseline, _ = app.measure_qos_perf({}, False) tuner = app.get_tuner() -tuner.tune(500, 2.1, 3.0, True, 50) -tuner.write_configs_to_dir("tuner_results/test") +tuner.tune(100, 2.1, 3.0, True, 50, perf_model="perf_linear", qos_model="qos_p1") +tuner.dump_configs("tuner_results/test/configs.json") diff --git a/test/test_torchapp.py b/test/test_torchapp.py index 12c9121c9c4e8be01fb3c89b862780b6ae601c1f..2b4650bef46c08fdf2b816b5272a913e4281d476 100644 --- a/test/test_torchapp.py +++ b/test/test_torchapp.py @@ -113,3 +113,45 @@ class TestModeledTuning(TorchAppSetUp): perf_model="perf_linear", qos_model="qos_p2", ) + + +class TestModelSaving(TorchAppSetUp): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.baseline, _ = cls.app.measure_qos_perf({}, False) + cls.model_path = "/tmp/test_models" + app = cls.get_app() + app.init_model("qos_p1") + app.init_model("qos_p2") + cls.app = cls.get_app() + + @classmethod + def get_app(cls): + return TorchApp( + "TestTorchApp", + cls.module, + DataLoader(cls.dataset, batch_size=500), + DataLoader(cls.dataset, batch_size=500), + get_knobs_from_file(), + accuracy, + model_storage_folder=cls.model_path + ) + + def test_loading_p1(self): + self.app.get_tuner().tune( + 100, + 3.0, + is_threshold_relative=True, + perf_model="perf_linear", + qos_model="qos_p1", + ) + + def test_loading_p2(self): + self.app.get_tuner().tune( + 100, + 3.0, + is_threshold_relative=True, + perf_model="perf_linear", + qos_model="qos_p2", + )