# integrated_tuning.py
  • import site
    from pathlib import Path
    
    import torch
    from torch.utils.data.dataloader import DataLoader
    from torch.utils.data.dataset import Subset
    
    site.addsitedir(Path(__file__).absolute().parent.parent)
    from model_zoo import CIFAR, VGG16Cifar10
    from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
    
    msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True)
    
    dataset = CIFAR.from_file(
        "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
    )
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    tune_loader = DataLoader(Subset(dataset, range(500)), batch_size=500)
    calib_loader = DataLoader(Subset(dataset, range(5000, 5500)), batch_size=500)
    
    module = VGG16Cifar10()
    module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
    app = TorchApp(
        "TestTorchApp", module, tune_loader, calib_loader, get_knobs_from_file(), accuracy,
    
    Yifan Zhao's avatar
    Yifan Zhao committed
        model_storage_folder="tuner_results/vgg16_cifar10"
    
    )
    baseline, _ = app.measure_qos_perf({}, False)
    tuner = app.get_tuner()
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    tuner.tune(100, 2.1, 3.0, True, 50, perf_model="perf_linear", qos_model="qos_p1")
    tuner.dump_configs("tuner_results/test/configs.json")