Skip to content
Snippets Groups Projects
Commit d260f0a6 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Added integrated test case (also example)

parent f9557aac
No related branches found
No related tags found
No related merge requests found
...@@ -3,8 +3,8 @@ import logging ...@@ -3,8 +3,8 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Generic, List, NamedTuple, Optional, Tuple, TypeVar, Union from typing import Dict, Generic, List, NamedTuple, Optional, Tuple, TypeVar, Union
import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
from opentuner.measurement.interface import MeasurementInterface from opentuner.measurement.interface import MeasurementInterface
from opentuner.resultsdb.models import Configuration, Result from opentuner.resultsdb.models import Configuration, Result
from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
...@@ -162,12 +162,16 @@ class ApproxTuner(Generic[T]): ...@@ -162,12 +162,16 @@ class ApproxTuner(Generic[T]):
return [configs[i] for i in taken_idx] return [configs[i] for i in taken_idx]
def write_configs_to_dir(self, directory: PathLike): def write_configs_to_dir(self, directory: PathLike):
import os
from jsonpickle import encode from jsonpickle import encode
if not self.tuned: if not self.tuned:
raise RuntimeError( raise RuntimeError(
f"No tuning session has been run; call self.tune() first." f"No tuning session has been run; call self.tune() first."
) )
directory = Path(directory)
os.makedirs(directory, exist_ok=True)
encode(self.kept_configs, directory) encode(self.kept_configs, directory)
def plot_configs(self) -> plt.Figure: def plot_configs(self) -> plt.Figure:
......
import site
from pathlib import Path
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset
site.addsitedir(Path(__file__).absolute().parent.parent)
from model_zoo import CIFAR, VGG16Cifar10
from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True)
dataset = CIFAR.from_file(
"model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
)
tune_loader = DataLoader(Subset(dataset, range(5000)), batch_size=500)
calib_loader = DataLoader(Subset(dataset, range(5000, 10000)), batch_size=500)
module = VGG16Cifar10()
module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
app = TorchApp(
"TestTorchApp", module, tune_loader, calib_loader, get_knobs_from_file(), accuracy,
)
baseline, _ = app.measure_qos_perf({}, False)
tuner = app.get_tuner()
tuner.tune(500, 2.1, 3.0, True, 50)
tuner.write_configs_to_dir("tuner_results/test")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment