"""Tune VGG16 (CIFAR-10) with predtuner's ``TorchApp`` and save the configs.

Loads pretrained ``VGG16Cifar10`` weights and the CIFAR-10 binaries from
``model_data/``. It wraps the model in a ``TorchApp`` with separate tuning and
calibration subsets, runs the tuner, and writes the resulting configurations
to ``tuner_results/test``.

NOTE: the data/model/output paths below are relative to the current working
directory (only the ``model_zoo`` lookup is relative to this file), so run the
script from the directory that contains ``model_data/``.
"""
import site
from pathlib import Path

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset

# Add the parent of this script's directory to sys.path so that ``model_zoo``
# below can be imported no matter where the script is launched from.
site.addsitedir(str(Path(__file__).absolute().parent.parent))

# These imports must come after the sys.path change above.
from model_zoo import CIFAR, VGG16Cifar10  # noqa: E402
from predtuner import (  # noqa: E402
    TorchApp,
    accuracy,
    config_pylogger,
    get_knobs_from_file,
)

msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True)

dataset = CIFAR.from_file(
    "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
)
# Disjoint subsets: samples 0-4999 drive tuning, samples 5000-9999 are used
# for calibration (10 batches of 500 each).
tune_loader = DataLoader(Subset(dataset, range(5000)), batch_size=500)
calib_loader = DataLoader(Subset(dataset, range(5000, 10000)), batch_size=500)

module = VGG16Cifar10()
# map_location="cpu" lets a checkpoint saved with GPU tensors load on a
# CPU-only machine. load_state_dict copies the values into the module's
# existing parameters, so where the tensors are deserialized does not change
# the result.
module.load_state_dict(
    torch.load("model_data/vgg16_cifar10.pth.tar", map_location="cpu")
)

app = TorchApp(
    "TestTorchApp",
    module,
    tune_loader,
    calib_loader,
    get_knobs_from_file(),
    accuracy,
)

# NOTE(review): presumably measures the unapproximated baseline (empty
# config); the meaning of the positional False flag should be confirmed
# against predtuner's measure_qos_perf. The result is currently unused.
baseline, _ = app.measure_qos_perf({}, False)

tuner = app.get_tuner()
# NOTE(review): positional tune() arguments; presumably (max_iter=500,
# qos_tuner_threshold=2.1, qos_keep_threshold=3.0,
# is_threshold_relative=True, take_best_n=50). Confirm against predtuner's
# ApproxTuner.tune signature before switching to keywords.
tuner.tune(500, 2.1, 3.0, True, 50)
tuner.write_configs_to_dir("tuner_results/test")