Newer
Older
import site
from pathlib import Path
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset
site.addsitedir(Path(__file__).absolute().parent.parent)
from model_zoo import CIFAR, VGG16Cifar10
from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
# Tune approximation knobs for a VGG16 model on CIFAR-10 with predtuner and
# write the resulting configurations to "tuner_results/test".

msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True)

# Samples [0, N_TUNE) feed the tuning loader; samples
# [N_TUNE, N_TUNE + N_CALIB) feed the calibration loader. Both are read in
# batches of BATCH_SIZE.
BATCH_SIZE = 500
N_TUNE = 5000
N_CALIB = 5000

dataset = CIFAR.from_file(
    "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
)
tune_loader = DataLoader(Subset(dataset, range(N_TUNE)), batch_size=BATCH_SIZE)
calib_loader = DataLoader(
    Subset(dataset, range(N_TUNE, N_TUNE + N_CALIB)), batch_size=BATCH_SIZE
)

module = VGG16Cifar10()
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine. load_state_dict copies the values into the freshly built
# module's existing parameters, so the module's placement is unchanged.
module.load_state_dict(
    torch.load("model_data/vgg16_cifar10.pth.tar", map_location="cpu")
)

app = TorchApp(
    "TestTorchApp", module, tune_loader, calib_loader, get_knobs_from_file(), accuracy,
)

# Measure with an empty knob configuration, presumably the unapproximated
# baseline (TODO confirm against predtuner). The first element of the
# returned pair was previously discarded; log it so the reference is recorded.
baseline, _ = app.measure_qos_perf({}, False)
msg_logger.info("Baseline QoS: %s", baseline)

tuner = app.get_tuner()
# Positional arguments are presumably, in predtuner's ApproxTuner.tune order:
# max_iter=500, qos_tuner_threshold=2.1, qos_keep_threshold=3.0,
# is_threshold_relative=True, take_best_n=50.
# TODO confirm against the installed predtuner version before switching to
# keyword arguments.
tuner.tune(500, 2.1, 3.0, True, 50)
tuner.write_configs_to_dir("tuner_results/test")