Skip to content
Snippets Groups Projects
Commit 82ca1653 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Added pareto tool and test

parent fb889b56
No related branches found
No related tags found
No related merge requests found
from typing import List
import numpy as np
def _find_distance_to(points: np.ndarray, ref_points: np.ndarray) -> np.ndarray:
n_ref = len(ref_points)
if n_ref == 0:
return np.zeros(0)
if n_ref == 1:
return np.linalg.norm(points - ref_points, axis=1)
ref_points = np.array(sorted(ref_points, key=lambda p: p[0]))
px = points.T[0]
rx = ref_points.T[0]
local_unit_vecs = ref_points[1:] - ref_points[:-1]
dists = []
bins = np.digitize(px, rx) - 1
for point, left_ref_p in zip(points, bins):
if left_ref_p == -1:
left_ref_p = 0
to_left_ref = ref_points[left_ref_p] - point
local_unit_vec = local_unit_vecs[-1] if left_ref_p >= n_ref - 1 else local_unit_vecs[left_ref_p]
projection = np.dot(local_unit_vec, to_left_ref) / np.linalg.norm(local_unit_vec)
dist = np.sqrt(np.linalg.norm(to_left_ref) ** 2 - projection ** 2)
dists.append(dist)
return np.array(dists)
def is_pareto_efficient(points: np.ndarray, take_n: int = None) -> List[int]:
is_pareto = np.ones(points.shape[0], dtype=bool)
for idx, c in enumerate(points):
if is_pareto[idx]:
# Keep any point with a higher value
is_pareto[is_pareto] = np.any(points[is_pareto] > c, axis=1)
is_pareto[idx] = True # And keep self
non_pareto = np.logical_not(is_pareto)
pareto_idx = is_pareto.nonzero()[0]
non_pareto_idx = non_pareto.nonzero()[0]
non_pareto_dist_to_pareto = _find_distance_to(points[non_pareto], points[is_pareto])
dist_order = np.argsort(non_pareto_dist_to_pareto)
take_n_non_pareto = 0 if take_n is None else take_n - len(pareto_idx)
dist_order = dist_order[:take_n_non_pareto]
taken_non_pareto_idx = non_pareto_idx[dist_order]
return pareto_idx.tolist() + taken_non_pareto_idx.tolist()
......@@ -3,11 +3,13 @@ import logging
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
import matplotlib.pyplot as plt
from opentuner.measurement.interface import MeasurementInterface
from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
from ._logging import override_opentuner_config
from ._pareto import is_pareto_efficient
msg_logger = logging.getLogger(__name__)
KnobsT = Dict[str, str]
......@@ -112,6 +114,12 @@ class ApproxTuner:
self.keep_threshold = qos_keep_threshold
return self.kept_configs
def take_best_configs(self, n: Optional[int] = None) -> List[Config]:
configs = self.kept_configs
points = np.array([[c.perf, c.qos] for c in configs])
taken_idx = is_pareto_efficient(points, take_n=n)
return [configs[i] for i in taken_idx]
def write_configs_to_dir(self, directory: PathLike):
from jsonpickle import encode
......
......@@ -10,7 +10,7 @@ from torch.utils.data.dataset import Subset
msg_logger = config_pylogger(output_dir="/tmp", verbose=True)
class TestTorchApp(unittest.TestCase):
class TestTorchAppInit(unittest.TestCase):
def setUp(self):
dataset = CIFAR.from_file(
"model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
......@@ -18,9 +18,7 @@ class TestTorchApp(unittest.TestCase):
self.dataset = Subset(dataset, range(100))
self.module = VGG16Cifar10()
self.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
def get_app(self):
return TorchApp(
self.app = TorchApp(
"TestTorchApp",
self.module,
DataLoader(self.dataset, batch_size=500),
......@@ -29,11 +27,10 @@ class TestTorchApp(unittest.TestCase):
accuracy,
)
def test_init(self):
app = self.get_app()
n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
def test_knobs(self):
n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
self.assertEqual(len(n_knobs), 34)
for op_name, op in app.midx.name_to_module.items():
for op_name, op in self.app.midx.name_to_module.items():
if isinstance(op, Conv2d):
nknob = 56
elif isinstance(op, Linear):
......@@ -43,15 +40,25 @@ class TestTorchApp(unittest.TestCase):
self.assertEqual(n_knobs[op_name], nknob)
def test_baseline_qos(self):
app = self.get_app()
qos, _ = app.measure_qos_perf({}, False)
qos, _ = self.app.measure_qos_perf({}, False)
self.assertAlmostEqual(qos, 88.0)
class TestTorchAppTuner(TestTorchAppInit):
def setUp(self):
super().setUp()
self.baseline, _ = self.app.measure_qos_perf({}, False)
self.tuner = self.app.get_tuner()
self.tuner.tune(100, self.baseline - 3.0)
def test_tuning(self):
app = self.get_app()
baseline, _ = app.measure_qos_perf({}, False)
tuner = app.get_tuner()
tuner.tune(100, baseline - 3.0)
configs = tuner.kept_configs
configs = self.tuner.kept_configs
for conf in configs:
self.assertTrue(conf.qos > baseline - 3.0)
self.assertTrue(conf.qos > self.baseline - 3.0)
def test_pareto(self):
configs = self.tuner.take_best_configs()
for c1 in configs:
self.assertFalse(
any(c2.qos > c1.qos and c2.perf > c1.perf for c2 in configs)
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment