diff --git a/predtuner/approxes/approxes.py b/predtuner/approxes/approxes.py index 22de3534c31ad34de786a0c912d5946100a0b1bd..25aaf9221c5f0463b870f63cea497cf2039508f0 100644 --- a/predtuner/approxes/approxes.py +++ b/predtuner/approxes/approxes.py @@ -1,11 +1,11 @@ """Approximation techniques for torch.nn layers.""" from pathlib import Path -from typing import Dict, Iterable, List, Optional, Set, Type, Union +from typing import Dict, Iterable, Set, Type, Union import torch from torch.nn import Conv2d, Linear, Module, Parameter -from ..torchapp import TorchApproxKnob +from ..torchapp import BaselineKnob, TorchApproxKnob from ._copy import module_only_deepcopy PathLike = Union[Path, str] @@ -374,7 +374,7 @@ class FP16Approx(TorchApproxKnob): default_name_to_class = { k.__name__: k - for k in [FP16Approx, PromiseSim, PerforateConv2dStride, Conv2dSampling] + for k in [FP16Approx, PromiseSim, PerforateConv2dStride, Conv2dSampling, BaselineKnob] } default_knob_file = Path(__file__).parent / "default_approx_params.json" diff --git a/predtuner/approxes/default_approx_params.json b/predtuner/approxes/default_approx_params.json index 5cb5fb53c972746e45f262c44e2a9d21f7b1fa7f..9c8833cd19dca83bff192f75e5b12b08f7da9257 100644 --- a/predtuner/approxes/default_approx_params.json +++ b/predtuner/approxes/default_approx_params.json @@ -1,4 +1,7 @@ [{ + "class": "BaselineKnob", + "name": "11" +}, { "class": "FP16Approx", "name": "12", "exp_speedup": 1.5 diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py index 6bc700b17c634f3a3a3530f632ca800c93a3a1a3..826de35329d4a3b9c5b77aee3d329a5f1acdbc9b 100644 --- a/predtuner/torchapp.py +++ b/predtuner/torchapp.py @@ -7,8 +7,14 @@ from torch.nn import Module from torch.utils.data.dataloader import DataLoader from .approxapp import ApproxKnob, KnobsT -from .modeledapp import (IPerfModel, IQoSModel, LinearPerfModel, ModeledApp, - QoSModelP1, QoSModelP2) +from .modeledapp import ( + IPerfModel, + IQoSModel, + LinearPerfModel, + ModeledApp, + QoSModelP1, + QoSModelP2, +) from .torchutil import ModuleIndexer, get_summary, move_to_device_recursively @@ -63,7 +69,7 @@ class TorchApp(ModeledApp, abc.ABC): self.module = module self.val_loader = val_loader self.test_loader = test_loader - self.name_to_knob = {k.name: k for k in knobs} + self.name_to_knob = {k.name: k for k in self._check_baseline_knob(knobs)} self.tensor_to_qos = tensor_to_qos self.combine_qos = combine_qos self.device = device @@ -76,9 +82,11 @@ class TorchApp(ModeledApp, abc.ABC): modules = self.midx.name_to_module summary = get_summary(self.module, (self._sample_input(),)) for op_name, op in modules.items(): - self._op_knobs[op_name] = [ + op_knobs = [ knob for knob in self.name_to_knob.values() if knob.is_applicable(op) ] + assert op_knobs + self._op_knobs[op_name] = op_knobs self._op_costs[op_name] = summary.loc[op_name, "flops"] # Init parent class last @@ -142,6 +150,20 @@ class TorchApp(ModeledApp, abc.ABC): all_outputs.append(outputs) return torch.stack(all_outputs) + @staticmethod + def _check_baseline_knob(knobs: Set[TorchApproxKnob]) -> Set[TorchApproxKnob]: + last_baseline_knob = None + for k in knobs: + if not isinstance(k, BaselineKnob): + continue + if last_baseline_knob is None: + last_baseline_knob = k + else: + raise ValueError(f"Found more than 1 baseline knobs: {last_baseline_knob} and {k}") + if last_baseline_knob is None: + knobs.add(BaselineKnob()) + return knobs + def _apply_knobs(self, knobs: KnobsT) -> Module: import copy @@ -154,3 +176,22 @@ class TorchApp(ModeledApp, abc.ABC): def _sample_input(self): inputs, _ = next(iter(self.val_loader)) return inputs.to(self.device) + + +class BaselineKnob(TorchApproxKnob): + def __init__(self, name: str = "__baseline__"): + super().__init__(name) + + @property + def deterministic(self) -> bool: + return True + + @property + def expected_speedup(self) -> float: + return 1.0 + + def is_applicable(self, op: Module) -> bool: + return True + + def apply(self, op: Module) -> Module: + return op diff --git a/test/test_torchapp.py b/test/test_torchapp.py index 094c5ddebb14ac49f3788a45ad0a6303b56af056..2f3514a89517dc852a197b31167c2c103c50c84e 100644 --- a/test/test_torchapp.py +++ b/test/test_torchapp.py @@ -3,6 +3,7 @@ import unittest from predtuner.approxes import get_knobs_from_file from predtuner.torchapp import TorchApp from predtuner.torchutil import accuracy +from torch.nn import Conv2d, Linear from torch.utils.data.dataloader import DataLoader from torchvision import transforms from torchvision.datasets import CIFAR10 @@ -24,3 +25,12 @@ class TestTorchAppInit(unittest.TestCase): get_knobs_from_file(), accuracy, ) + n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()} + for op_name, op in app.midx.name_to_module.items(): + if isinstance(op, Conv2d): + nknob = 63 + elif isinstance(op, Linear): + nknob = 9 + else: + nknob = 1 + self.assertEqual(n_knobs[op_name], nknob)