Skip to content
Snippets Groups Projects
Commit 4fa505e0 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Added baseline knob

parent d66d2c43
No related branches found
No related tags found
No related merge requests found
"""Approximation techniques for torch.nn layers.""" """Approximation techniques for torch.nn layers."""
from pathlib import Path from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Type, Union from typing import Dict, Iterable, Set, Type, Union
import torch import torch
from torch.nn import Conv2d, Linear, Module, Parameter from torch.nn import Conv2d, Linear, Module, Parameter
from ..torchapp import TorchApproxKnob from ..torchapp import BaselineKnob, TorchApproxKnob
from ._copy import module_only_deepcopy from ._copy import module_only_deepcopy
PathLike = Union[Path, str] PathLike = Union[Path, str]
...@@ -374,7 +374,7 @@ class FP16Approx(TorchApproxKnob): ...@@ -374,7 +374,7 @@ class FP16Approx(TorchApproxKnob):
default_name_to_class = { default_name_to_class = {
k.__name__: k k.__name__: k
for k in [FP16Approx, PromiseSim, PerforateConv2dStride, Conv2dSampling] for k in [FP16Approx, PromiseSim, PerforateConv2dStride, Conv2dSampling, BaselineKnob]
} }
default_knob_file = Path(__file__).parent / "default_approx_params.json" default_knob_file = Path(__file__).parent / "default_approx_params.json"
......
[{ [{
"class": "BaselineKnob",
"name": "11"
}, {
"class": "FP16Approx", "class": "FP16Approx",
"name": "12", "name": "12",
"exp_speedup": 1.5 "exp_speedup": 1.5
......
...@@ -7,8 +7,14 @@ from torch.nn import Module ...@@ -7,8 +7,14 @@ from torch.nn import Module
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from .approxapp import ApproxKnob, KnobsT from .approxapp import ApproxKnob, KnobsT
from .modeledapp import (IPerfModel, IQoSModel, LinearPerfModel, ModeledApp, from .modeledapp import (
QoSModelP1, QoSModelP2) IPerfModel,
IQoSModel,
LinearPerfModel,
ModeledApp,
QoSModelP1,
QoSModelP2,
)
from .torchutil import ModuleIndexer, get_summary, move_to_device_recursively from .torchutil import ModuleIndexer, get_summary, move_to_device_recursively
...@@ -63,7 +69,7 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -63,7 +69,7 @@ class TorchApp(ModeledApp, abc.ABC):
self.module = module self.module = module
self.val_loader = val_loader self.val_loader = val_loader
self.test_loader = test_loader self.test_loader = test_loader
self.name_to_knob = {k.name: k for k in knobs} self.name_to_knob = {k.name: k for k in self._check_baseline_knob(knobs)}
self.tensor_to_qos = tensor_to_qos self.tensor_to_qos = tensor_to_qos
self.combine_qos = combine_qos self.combine_qos = combine_qos
self.device = device self.device = device
...@@ -76,9 +82,11 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -76,9 +82,11 @@ class TorchApp(ModeledApp, abc.ABC):
modules = self.midx.name_to_module modules = self.midx.name_to_module
summary = get_summary(self.module, (self._sample_input(),)) summary = get_summary(self.module, (self._sample_input(),))
for op_name, op in modules.items(): for op_name, op in modules.items():
self._op_knobs[op_name] = [ op_knobs = [
knob for knob in self.name_to_knob.values() if knob.is_applicable(op) knob for knob in self.name_to_knob.values() if knob.is_applicable(op)
] ]
assert op_knobs
self._op_knobs[op_name] = op_knobs
self._op_costs[op_name] = summary.loc[op_name, "flops"] self._op_costs[op_name] = summary.loc[op_name, "flops"]
# Init parent class last # Init parent class last
...@@ -142,6 +150,20 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -142,6 +150,20 @@ class TorchApp(ModeledApp, abc.ABC):
all_outputs.append(outputs) all_outputs.append(outputs)
return torch.stack(all_outputs) return torch.stack(all_outputs)
@staticmethod
def _check_baseline_knob(knobs: Set[TorchApproxKnob]) -> Set[TorchApproxKnob]:
last_baseline_knob = None
for k in knobs:
if not isinstance(k, BaselineKnob):
continue
if last_baseline_knob is None:
last_baseline_knob = k
else:
raise ValueError(f"Found more than 1 baseline knobs: {last_baseline_knob} and {k}")
if last_baseline_knob is None:
knobs.add(BaselineKnob())
return knobs
def _apply_knobs(self, knobs: KnobsT) -> Module: def _apply_knobs(self, knobs: KnobsT) -> Module:
import copy import copy
...@@ -154,3 +176,22 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -154,3 +176,22 @@ class TorchApp(ModeledApp, abc.ABC):
def _sample_input(self): def _sample_input(self):
inputs, _ = next(iter(self.val_loader)) inputs, _ = next(iter(self.val_loader))
return inputs.to(self.device) return inputs.to(self.device)
class BaselineKnob(TorchApproxKnob):
def __init__(self, name: str = "__baseline__"):
super().__init__(name)
@property
def deterministic(self) -> bool:
return True
@property
def expected_speedup(self) -> float:
return 1.0
def is_applicable(self, op: Module) -> bool:
return True
def apply(self, op: Module) -> Module:
return op
...@@ -3,6 +3,7 @@ import unittest ...@@ -3,6 +3,7 @@ import unittest
from predtuner.approxes import get_knobs_from_file from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy from predtuner.torchutil import accuracy
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
...@@ -24,3 +25,12 @@ class TestTorchAppInit(unittest.TestCase): ...@@ -24,3 +25,12 @@ class TestTorchAppInit(unittest.TestCase):
get_knobs_from_file(), get_knobs_from_file(),
accuracy, accuracy,
) )
n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
for op_name, op in app.midx.name_to_module.items():
if isinstance(op, Conv2d):
nknob = 63
elif isinstance(op, Linear):
nknob = 9
else:
nknob = 1
self.assertEqual(n_knobs[op_name], nknob)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment