diff --git a/predtuner/_pareto.py b/predtuner/_pareto.py
new file mode 100644
index 0000000000000000000000000000000000000000..4396ab1cc9d6cd02b3c0d0595a55085d8ea3d064
--- /dev/null
+++ b/predtuner/_pareto.py
@@ -0,0 +1,45 @@
+from typing import List
+
+import numpy as np
+
+
+def _find_distance_to(points: np.ndarray, ref_points: np.ndarray) -> np.ndarray:
+    n_ref = len(ref_points)
+    if n_ref == 0:
+        return np.zeros(0)
+    if n_ref == 1:
+        return np.linalg.norm(points - ref_points, axis=1)
+    ref_points = np.array(sorted(ref_points, key=lambda p: p[0]))
+    px = points.T[0]
+    rx = ref_points.T[0]
+    local_unit_vecs = ref_points[1:] - ref_points[:-1]
+    dists = []
+    bins = np.digitize(px, rx) - 1
+    for point, left_ref_p in zip(points, bins):
+        if left_ref_p == -1:
+            left_ref_p = 0
+        to_left_ref = ref_points[left_ref_p] - point
+        local_unit_vec = local_unit_vecs[-1] if left_ref_p >= n_ref - 1 else local_unit_vecs[left_ref_p]
+        projection = np.dot(local_unit_vec, to_left_ref) / np.linalg.norm(local_unit_vec)
+        dist = np.sqrt(np.linalg.norm(to_left_ref) ** 2 - projection ** 2)
+        dists.append(dist)
+    return np.array(dists)
+
+
+def is_pareto_efficient(points: np.ndarray, take_n: int = None) -> List[int]:
+    is_pareto = np.ones(points.shape[0], dtype=bool)
+    for idx, c in enumerate(points):
+        if is_pareto[idx]:
+            # Keep any point with a higher value
+            is_pareto[is_pareto] = np.any(points[is_pareto] > c, axis=1)
+            is_pareto[idx] = True  # And keep self
+    non_pareto = np.logical_not(is_pareto)
+    pareto_idx = is_pareto.nonzero()[0]
+    non_pareto_idx = non_pareto.nonzero()[0]
+
+    non_pareto_dist_to_pareto = _find_distance_to(points[non_pareto], points[is_pareto])
+    dist_order = np.argsort(non_pareto_dist_to_pareto)
+    take_n_non_pareto = 0 if take_n is None else take_n - len(pareto_idx)
+    dist_order = dist_order[:take_n_non_pareto]
+    taken_non_pareto_idx = non_pareto_idx[dist_order]
+    return pareto_idx.tolist() + taken_non_pareto_idx.tolist()
diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py
index e06c81fd2898be38677a69df9099a3a2ce72b97a..e2800fe97d764b9a9f39a10a31151bb7ac5d2ae2 100644
--- a/predtuner/approxapp.py
+++ b/predtuner/approxapp.py
@@ -3,11 +3,13 @@ import logging
 from pathlib import Path
 from typing import Dict, List, NamedTuple, Optional, Tuple, Union
 
+import numpy as np
 import matplotlib.pyplot as plt
 from opentuner.measurement.interface import MeasurementInterface
 from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
 
 from ._logging import override_opentuner_config
+from ._pareto import is_pareto_efficient
 
 msg_logger = logging.getLogger(__name__)
 KnobsT = Dict[str, str]
@@ -112,6 +114,12 @@ class ApproxTuner:
         self.keep_threshold = qos_keep_threshold
         return self.kept_configs
 
+    def take_best_configs(self, n: Optional[int] = None) -> List[Config]:
+        configs = self.kept_configs
+        points = np.array([[c.perf, c.qos] for c in configs])
+        taken_idx = is_pareto_efficient(points, take_n=n)
+        return [configs[i] for i in taken_idx]
+
     def write_configs_to_dir(self, directory: PathLike):
         from jsonpickle import encode
 
diff --git a/test/test_torchapp.py b/test/test_torchapp.py
index c16af2fbfb92fd9ffcf5c036da8583c1ef0d021d..75fb06b573adb40191cc58cb9b1022dd4041ef9b 100644
--- a/test/test_torchapp.py
+++ b/test/test_torchapp.py
@@ -10,7 +10,7 @@ from torch.utils.data.dataset import Subset
 msg_logger = config_pylogger(output_dir="/tmp", verbose=True)
 
 
-class TestTorchApp(unittest.TestCase):
+class TestTorchAppInit(unittest.TestCase):
     def setUp(self):
         dataset = CIFAR.from_file(
             "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
@@ -18,9 +18,7 @@ class TestTorchApp(unittest.TestCase):
         self.dataset = Subset(dataset, range(100))
         self.module = VGG16Cifar10()
         self.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
-
-    def get_app(self):
-        return TorchApp(
+        self.app = TorchApp(
             "TestTorchApp",
             self.module,
             DataLoader(self.dataset, batch_size=500),
@@ -29,11 +27,10 @@ class TestTorchApp(unittest.TestCase):
             accuracy,
         )
 
-    def test_init(self):
-        app = self.get_app()
-        n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
+    def test_knobs(self):
+        n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
         self.assertEqual(len(n_knobs), 34)
-        for op_name, op in app.midx.name_to_module.items():
+        for op_name, op in self.app.midx.name_to_module.items():
             if isinstance(op, Conv2d):
                 nknob = 56
             elif isinstance(op, Linear):
@@ -43,15 +40,25 @@ class TestTorchApp(unittest.TestCase):
             self.assertEqual(n_knobs[op_name], nknob)
 
     def test_baseline_qos(self):
-        app = self.get_app()
-        qos, _ = app.measure_qos_perf({}, False)
+        qos, _ = self.app.measure_qos_perf({}, False)
         self.assertAlmostEqual(qos, 88.0)
 
+
+class TestTorchAppTuner(TestTorchAppInit):
+    def setUp(self):
+        super().setUp()
+        self.baseline, _ = self.app.measure_qos_perf({}, False)
+        self.tuner = self.app.get_tuner()
+        self.tuner.tune(100, self.baseline - 3.0)
+
     def test_tuning(self):
-        app = self.get_app()
-        baseline, _ = app.measure_qos_perf({}, False)
-        tuner = app.get_tuner()
-        tuner.tune(100, baseline - 3.0)
-        configs = tuner.kept_configs
+        configs = self.tuner.kept_configs
         for conf in configs:
-            self.assertTrue(conf.qos > baseline - 3.0)
+            self.assertTrue(conf.qos > self.baseline - 3.0)
+
+    def test_pareto(self):
+        configs = self.tuner.take_best_configs()
+        for c1 in configs:
+            self.assertFalse(
+                any(c2.qos > c1.qos and c2.perf > c1.perf for c2 in configs)
+            )