Skip to content
Snippets Groups Projects
integrated_tuning.py 1006 B
Newer Older
  • Learn to ignore specific revisions
  • import site
    from pathlib import Path
    
    import torch
    from torch.utils.data.dataloader import DataLoader
    from torch.utils.data.dataset import Subset
    
    site.addsitedir(Path(__file__).absolute().parent.parent)
    from model_zoo import CIFAR, VGG16Cifar10
    from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
    
    # Script body: build a TorchApp around a VGG16 model with restored weights,
    # run the predtuner tuner on it, and write the resulting configurations out.
    # Every data/model/output path below is relative, so it resolves against
    # the current working directory of the process.

    # NOTE(review): presumably configures predtuner's logging to write under
    # "tuner_results/logs" -- confirm against config_pylogger's documentation.
    msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True)

    # Load the dataset once, then split it by index into two disjoint ranges:
    # samples [0, 5000) feed tune_loader and [5000, 10000) feed calib_loader.
    dataset = CIFAR.from_file(
        "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
    )
    tune_loader = DataLoader(Subset(dataset, range(5000)), batch_size=500)
    calib_loader = DataLoader(Subset(dataset, range(5000, 10000)), batch_size=500)

    # Build the network and restore its weights from the checkpoint file.
    # map_location="cpu" lets a checkpoint that was saved from CUDA tensors be
    # read on a machine without a GPU; load_state_dict() then copies the values
    # into the module's existing parameters, so the module's own device
    # placement is not affected by this choice.
    module = VGG16Cifar10()
    state_dict = torch.load("model_data/vgg16_cifar10.pth.tar", map_location="cpu")
    module.load_state_dict(state_dict)

    app = TorchApp(
        "TestTorchApp",
        module,
        tune_loader,
        calib_loader,
        get_knobs_from_file(),
        accuracy,
    )

    # NOTE(review): the empty dict is presumably "no knobs applied", i.e. a
    # baseline measurement -- confirm. `baseline` is not read again in this
    # script; the call is kept because it may have side effects inside `app`.
    baseline, _ = app.measure_qos_perf({}, False)

    tuner = app.get_tuner()
    # NOTE(review): the arguments are passed positionally; check them against
    # the tune() signature of the installed predtuner version before changing
    # any of these values.
    tuner.tune(500, 2.1, 3.0, True, 50)
    tuner.write_configs_to_dir("tuner_results/test")