Skip to content
Snippets Groups Projects
Commit 4fa505e0 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Added baseline knob

parent d66d2c43
No related branches found
No related tags found
No related merge requests found
"""Approximation techniques for torch.nn layers."""
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Type, Union
from typing import Dict, Iterable, Set, Type, Union
import torch
from torch.nn import Conv2d, Linear, Module, Parameter
from ..torchapp import TorchApproxKnob
from ..torchapp import BaselineKnob, TorchApproxKnob
from ._copy import module_only_deepcopy
PathLike = Union[Path, str]
......@@ -374,7 +374,7 @@ class FP16Approx(TorchApproxKnob):
# Registry mapping knob class names (as they appear in the JSON knob file)
# to their classes. BaselineKnob is included so "class": "BaselineKnob"
# entries in default_approx_params.json resolve.
# NOTE(review): the scraped diff left both the pre- and post-commit
# comprehension lines in place; this keeps the post-commit version.
default_name_to_class = {
    k.__name__: k
    for k in [FP16Approx, PromiseSim, PerforateConv2dStride, Conv2dSampling, BaselineKnob]
}
# Default knob parameter file shipped alongside this module.
default_knob_file = Path(__file__).parent / "default_approx_params.json"
......
[{
"class": "BaselineKnob",
"name": "11"
}, {
"class": "FP16Approx",
"name": "12",
"exp_speedup": 1.5
......
......@@ -7,8 +7,14 @@ from torch.nn import Module
from torch.utils.data.dataloader import DataLoader
from .approxapp import ApproxKnob, KnobsT
from .modeledapp import (IPerfModel, IQoSModel, LinearPerfModel, ModeledApp,
QoSModelP1, QoSModelP2)
from .modeledapp import (
IPerfModel,
IQoSModel,
LinearPerfModel,
ModeledApp,
QoSModelP1,
QoSModelP2,
)
from .torchutil import ModuleIndexer, get_summary, move_to_device_recursively
......@@ -63,7 +69,7 @@ class TorchApp(ModeledApp, abc.ABC):
self.module = module
self.val_loader = val_loader
self.test_loader = test_loader
self.name_to_knob = {k.name: k for k in knobs}
self.name_to_knob = {k.name: k for k in self._check_baseline_knob(knobs)}
self.tensor_to_qos = tensor_to_qos
self.combine_qos = combine_qos
self.device = device
......@@ -76,9 +82,11 @@ class TorchApp(ModeledApp, abc.ABC):
modules = self.midx.name_to_module
summary = get_summary(self.module, (self._sample_input(),))
for op_name, op in modules.items():
self._op_knobs[op_name] = [
op_knobs = [
knob for knob in self.name_to_knob.values() if knob.is_applicable(op)
]
assert op_knobs
self._op_knobs[op_name] = op_knobs
self._op_costs[op_name] = summary.loc[op_name, "flops"]
# Init parent class last
......@@ -142,6 +150,20 @@ class TorchApp(ModeledApp, abc.ABC):
all_outputs.append(outputs)
return torch.stack(all_outputs)
@staticmethod
def _check_baseline_knob(knobs: Set[TorchApproxKnob]) -> Set[TorchApproxKnob]:
    """Ensure the knob set contains exactly one baseline knob.

    Parameters
    ----------
    knobs: the set of knobs given to this app.

    Returns
    -------
    A new set equal to ``knobs``, with a default ``BaselineKnob()`` added
    if none was present.

    Raises
    ------
    ValueError: if more than one ``BaselineKnob`` instance is found.
    """
    baselines = [k for k in knobs if isinstance(k, BaselineKnob)]
    if len(baselines) > 1:
        raise ValueError(
            f"Found more than 1 baseline knobs: {baselines[0]} and {baselines[1]}"
        )
    # Work on a copy so the caller's set is never mutated in place
    # (the original version called knobs.add() on the argument).
    checked = set(knobs)
    if not baselines:
        checked.add(BaselineKnob())
    return checked
def _apply_knobs(self, knobs: KnobsT) -> Module:
import copy
......@@ -154,3 +176,22 @@ class TorchApp(ModeledApp, abc.ABC):
def _sample_input(self):
inputs, _ = next(iter(self.val_loader))
return inputs.to(self.device)
class BaselineKnob(TorchApproxKnob):
    """The identity ("no approximation") knob.

    Applicable to every layer; applying it returns the module unchanged,
    so it serves as the baseline point of the knob search space.
    """

    def __init__(self, name: str = "__baseline__"):
        super().__init__(name)

    @property
    def deterministic(self) -> bool:
        """Always True: the unmodified module has no injected randomness."""
        return True

    @property
    def expected_speedup(self) -> float:
        """Always 1.0: no approximation means no speedup."""
        return 1.0

    def is_applicable(self, op: Module) -> bool:
        """The baseline can be applied to any module."""
        return True

    def apply(self, op: Module) -> Module:
        """Return ``op`` untouched."""
        return op
......@@ -3,6 +3,7 @@ import unittest
from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
......@@ -24,3 +25,12 @@ class TestTorchAppInit(unittest.TestCase):
get_knobs_from_file(),
accuracy,
)
n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
for op_name, op in app.midx.name_to_module.items():
if isinstance(op, Conv2d):
nknob = 63
elif isinstance(op, Linear):
nknob = 9
else:
nknob = 1
self.assertEqual(n_knobs[op_name], nknob)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment