"""Tune VGG16 on CIFAR-10 with predtuner's TorchApp and dump the resulting configs.

All data/output paths below (``model_data/...``, ``tuner_results/...``) are
relative to the current working directory, so run this script from the
directory that contains ``model_data/``.
"""
import site
from pathlib import Path

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset

# Make the `model_zoo` package importable by adding this file's grandparent
# directory to the module search path (must happen before the imports below).
site.addsitedir(str(Path(__file__).absolute().parent.parent))

from model_zoo import CIFAR, VGG16Cifar10  # noqa: E402
from predtuner import (  # noqa: E402
    TorchApp,
    accuracy,
    config_pylogger,
    get_knobs_from_file,
)

# Called for its logging-configuration side effect; the returned object is not
# used further in this script.
msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True)

dataset = CIFAR.from_file(
    "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
)
# Two disjoint 500-sample slices: [0, 500) for tuning and [5000, 5500) for
# calibration, each delivered as a single batch of 500.
tune_loader = DataLoader(Subset(dataset, range(500)), batch_size=500)
calib_loader = DataLoader(Subset(dataset, range(5000, 5500)), batch_size=500)

module = VGG16Cifar10()
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; load_state_dict copies the values into the module's own
# parameters, so where the temporary tensors live does not change the result.
module.load_state_dict(
    torch.load("model_data/vgg16_cifar10.pth.tar", map_location="cpu")
)

app = TorchApp(
    "TestTorchApp",
    module,
    tune_loader,
    calib_loader,
    get_knobs_from_file(),
    accuracy,
    model_storage_folder="tuner_results/vgg16_cifar10",
)
# Measure QoS/perf with an empty configuration `{}` (presumably the
# no-approximation baseline — confirm). The result is not used further here.
baseline, _ = app.measure_qos_perf({}, False)

tuner = app.get_tuner()
# NOTE(review): the positional arguments (100, 2.1, 3.0, True, 50) are passed
# through unchanged; confirm their meaning against predtuner's tune() signature
# before editing them.
tuner.tune(100, 2.1, 3.0, True, 50, perf_model="perf_linear", qos_model="qos_p1")

# Create the output directory up front so a long tuning run is not lost to a
# missing-directory error when the configs are written at the very end.
Path("tuner_results/test").mkdir(parents=True, exist_ok=True)
tuner.dump_configs("tuner_results/test/configs.json")