"""Load each pretrained network listed below and print the QoS it measures.

For every entry of ``networks_in_folder`` the script instantiates the model
class, restores its weights from ``model_data/<name>.pth.tar``, wraps it in a
``predtuner.TorchApp`` and prints ``<name> -> <qos>``.  All data and
checkpoint paths are relative to the current working directory.
"""
import site
from pathlib import Path

import torch
from torch.nn.modules.module import Module
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset

# Register the directory two levels above this file as a site directory
# before the project imports below run (presumably so that they resolve
# without installation -- confirm against the repository layout).
two_levels_up = Path(__file__).absolute().parent.parent
site.addsitedir(two_levels_up)

import model_zoo as net
from predtuner import TorchApp, accuracy, get_knobs_from_file


def load_from_default_path(cls, prefix: str):
    """Build a dataset via ``cls.from_file`` from the two files under *prefix*.

    Args:
        cls: Any class exposing ``from_file(input_path, labels_path)``.
        prefix: Directory that holds ``input.bin`` and ``labels.bin``.

    Returns:
        Whatever ``cls.from_file`` returns for those two paths.
    """
    input_path = f"{prefix}/input.bin"
    labels_path = f"{prefix}/labels.bin"
    return cls.from_file(input_path, labels_path)


mnist = load_from_default_path(net.MNIST, "model_data/mnist")
cifar10 = load_from_default_path(net.CIFAR, "model_data/cifar10")
cifar100 = load_from_default_path(net.CIFAR, "model_data/cifar100")
# NOTE(review): `imagenet` is loaded but no entry of `networks_in_folder`
# in the visible code uses it; kept so the script's side effects are unchanged.
imagenet = load_from_default_path(net.ImageNet, "model_data/imagenet")

# Checkpoint base name -> (model class, dataset the model is evaluated on).
networks_in_folder = {
    "lenet_mnist": (net.LeNet, mnist),
    "alexnet_cifar10": (net.AlexNet, cifar10),
    "alexnet2_cifar10": (net.AlexNet2, cifar10),
    "vgg16_cifar10": (net.VGG16Cifar10, cifar10),
    "vgg16_cifar100": (net.VGG16Cifar100, cifar100),
}


def _measure_qos(name, model_cls, dataset):
    """Restore the checkpoint called *name* and return the QoS of the model."""
    network: Module = model_cls()
    state_dict = torch.load(f"model_data/{name}.pth.tar")
    network.load_state_dict(state_dict)

    # Two loaders with batch size 1: one restricted to samples 5000..9999,
    # one over the whole dataset.  They are handed to TorchApp in this order
    # (presumably as its tuning and test loaders -- confirm against the
    # predtuner TorchApp signature).
    subset_loader = DataLoader(Subset(dataset, range(5000, 10000)), batch_size=1)
    full_loader = DataLoader(dataset, batch_size=1)

    app = TorchApp(
        "",
        network,
        subset_loader,
        full_loader,
        get_knobs_from_file(),
        accuracy,
    )
    # Only the first element of the returned pair is reported.
    qos, _ = app.measure_qos_perf({}, False)
    return qos


for name, (cls, dataset) in networks_in_folder.items():
    print(f"{name} -> {_measure_qos(name, cls, dataset)}")