Skip to content
Snippets Groups Projects
Commit 5492d73e authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Updated test cases

parent d8da8106
No related branches found
No related tags found
No related merge requests found
...@@ -92,7 +92,7 @@ class ModeledApp(ApproxApp, abc.ABC): ...@@ -92,7 +92,7 @@ class ModeledApp(ApproxApp, abc.ABC):
def get_tuner(self) -> "ApproxModeledTuner": def get_tuner(self) -> "ApproxModeledTuner":
return ApproxModeledTuner(self) return ApproxModeledTuner(self)
def _init_model(self, model_name: str): def init_model(self, model_name: str):
self._name_to_model[model_name]._init() self._name_to_model[model_name]._init()
...@@ -371,10 +371,10 @@ class ApproxModeledTuner(ApproxTuner): ...@@ -371,10 +371,10 @@ class ApproxModeledTuner(ApproxTuner):
msg_logger.info("Starting tuning with %s and %s", qos_desc, perf_desc) msg_logger.info("Starting tuning with %s and %s", qos_desc, perf_desc)
if qos_model != "none": if qos_model != "none":
msg_logger.info("Initializing qos model %s", qos_model) msg_logger.info("Initializing qos model %s", qos_model)
self.app._init_model(qos_model) self.app.init_model(qos_model)
if perf_model != "none": if perf_model != "none":
msg_logger.info("Initializing performance model %s", perf_model) msg_logger.info("Initializing performance model %s", perf_model)
self.app._init_model(perf_model) self.app.init_model(perf_model)
ret = super().tune( ret = super().tune(
max_iter=max_iter, max_iter=max_iter,
qos_tuner_threshold=qos_tuner_threshold, qos_tuner_threshold=qos_tuner_threshold,
......
...@@ -70,7 +70,7 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -70,7 +70,7 @@ class TorchApp(ModeledApp, abc.ABC):
self.tensor_to_qos = tensor_to_qos self.tensor_to_qos = tensor_to_qos
self.combine_qos = combine_qos self.combine_qos = combine_qos
self.device = device self.device = device
self.model_storage = Path(model_storage_folder) self.model_storage = Path(model_storage_folder) if model_storage_folder else None
self.module = self.module.to(device) self.module = self.module.to(device)
self.midx = ModuleIndexer(module) self.midx = ModuleIndexer(module)
......
...@@ -14,14 +14,15 @@ msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True) ...@@ -14,14 +14,15 @@ msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True)
dataset = CIFAR.from_file( dataset = CIFAR.from_file(
"model_data/cifar10/input.bin", "model_data/cifar10/labels.bin" "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
) )
tune_loader = DataLoader(Subset(dataset, range(5000)), batch_size=500) tune_loader = DataLoader(Subset(dataset, range(500)), batch_size=500)
calib_loader = DataLoader(Subset(dataset, range(5000, 10000)), batch_size=500) calib_loader = DataLoader(Subset(dataset, range(5000, 5500)), batch_size=500)
module = VGG16Cifar10() module = VGG16Cifar10()
module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar")) module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
app = TorchApp( app = TorchApp(
"TestTorchApp", module, tune_loader, calib_loader, get_knobs_from_file(), accuracy, "TestTorchApp", module, tune_loader, calib_loader, get_knobs_from_file(), accuracy,
model_storage_folder="tuner_results/vgg16_cifar10"
) )
baseline, _ = app.measure_qos_perf({}, False) baseline, _ = app.measure_qos_perf({}, False)
tuner = app.get_tuner() tuner = app.get_tuner()
tuner.tune(500, 2.1, 3.0, True, 50) tuner.tune(100, 2.1, 3.0, True, 50, perf_model="perf_linear", qos_model="qos_p1")
tuner.write_configs_to_dir("tuner_results/test") tuner.dump_configs("tuner_results/test/configs.json")
...@@ -113,3 +113,45 @@ class TestModeledTuning(TorchAppSetUp): ...@@ -113,3 +113,45 @@ class TestModeledTuning(TorchAppSetUp):
perf_model="perf_linear", perf_model="perf_linear",
qos_model="qos_p2", qos_model="qos_p2",
) )
class TestModelSaving(TorchAppSetUp):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.baseline, _ = cls.app.measure_qos_perf({}, False)
cls.model_path = "/tmp/test_models"
app = cls.get_app()
app.init_model("qos_p1")
app.init_model("qos_p2")
cls.app = cls.get_app()
@classmethod
def get_app(cls):
return TorchApp(
"TestTorchApp",
cls.module,
DataLoader(cls.dataset, batch_size=500),
DataLoader(cls.dataset, batch_size=500),
get_knobs_from_file(),
accuracy,
model_storage_folder=cls.model_path
)
def test_loading_p1(self):
self.app.get_tuner().tune(
100,
3.0,
is_threshold_relative=True,
perf_model="perf_linear",
qos_model="qos_p1",
)
def test_loading_p2(self):
self.app.get_tuner().tune(
100,
3.0,
is_threshold_relative=True,
perf_model="perf_linear",
qos_model="qos_p2",
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment