From 544052ae8f39942f5413c6bb807de042a03171f9 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Thu, 28 Jan 2021 17:01:01 -0600
Subject: [PATCH] Improved checking for baseline knob

---
 predtuner/approxapp.py                        | 43 ++++++++++++++++---
 predtuner/approxes/approxes.py                |  4 +-
 predtuner/approxes/default_approx_params.json |  2 +-
 predtuner/modeledapp.py                       |  6 +--
 predtuner/pipedbin.py                         | 11 ++---
 predtuner/torchapp.py                         | 39 ++++++-----------
 test/test_torchapp.py                         |  5 ++-
 7 files changed, 63 insertions(+), 47 deletions(-)

diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py
index 8e871df..61d203b 100644
--- a/predtuner/approxapp.py
+++ b/predtuner/approxapp.py
@@ -30,13 +30,18 @@ class ApproxKnob:
 
 class ApproxApp(abc.ABC):
     """Generic approximable application with operator & knob enumeration,
-    and measures its own QoS and performance given a configuration."""
-
-    @property
-    @abc.abstractmethod
-    def op_knobs(self) -> Dict[str, List[ApproxKnob]]:
-        """Get a mapping from each operator (identified by str) to a list of applicable knobs."""
-        pass
+    and measures its own QoS and performance given a configuration.
+    
+    Parameters
+    ----------
+    op_knobs:
+        a mapping from each operator (identified by str) to a list of applicable knobs.
+    """
+    def __init__(self, op_knobs: Dict[str, List[ApproxKnob]]) -> None:
+        super().__init__()
+        self.op_knobs = op_knobs
+        # Also modifies self.op_knobs in place.
+        self.baseline_knob = self._check_get_baseline_knob_(self.op_knobs)
 
     @abc.abstractmethod
     def measure_qos_perf(
@@ -65,6 +70,30 @@ class ApproxApp(abc.ABC):
         knob_sets = [set(knobs) for knobs in self.op_knobs.values()]
         return list(set.union(*knob_sets))
 
+    @staticmethod
+    def _check_get_baseline_knob_(op_knobs: Dict[str, List[ApproxKnob]]) -> "BaselineKnob":
+        # Modifies op_knobs inplace.
+        # Find the baseline knob if the user has one, or get a default one
+        knob_sets = [set(knobs) for knobs in op_knobs.values()]
+        knobs = list(set.union(*knob_sets))
+        baselines = set(k for k in knobs if isinstance(k, BaselineKnob))
+        if len(baselines) > 1:
+            raise ValueError(f"Found multiple baseline knobs in op_knobs: {baselines}")
+        if baselines:
+            baseline_knob, = baselines
+        else:
+            baseline_knob = BaselineKnob()
+        # Start checking if each layer has the baseline knob
+        for knobs in op_knobs.values():
+            if baseline_knob not in set(knobs):
+                knobs.append(baseline_knob)
+        return baseline_knob
+
+
+class BaselineKnob(ApproxKnob):
+    def __init__(self, name: str = "__baseline__"):
+        super().__init__(name)
+
 
 class Config:
     def __init__(
diff --git a/predtuner/approxes/approxes.py b/predtuner/approxes/approxes.py
index 4882522..d4ab809 100644
--- a/predtuner/approxes/approxes.py
+++ b/predtuner/approxes/approxes.py
@@ -6,7 +6,7 @@ import torch
 from torch.nn import Conv2d, Linear, Module, Parameter
 
 from .._logging import PathLike
-from ..torchapp import BaselineKnob, TorchApproxKnob
+from ..torchapp import TorchBaselineKnob, TorchApproxKnob
 from ._copy import module_only_deepcopy
 
 
@@ -373,7 +373,7 @@ class FP16Approx(TorchApproxKnob):
 
 default_name_to_class = {
     k.__name__: k
-    for k in [FP16Approx, PromiseSim, PerforateConv2dStride, Conv2dSampling, BaselineKnob]
+    for k in [FP16Approx, PromiseSim, PerforateConv2dStride, Conv2dSampling, TorchBaselineKnob]
 }
 default_knob_file = Path(__file__).parent / "default_approx_params.json"
 
diff --git a/predtuner/approxes/default_approx_params.json b/predtuner/approxes/default_approx_params.json
index 8e55440..8583c08 100644
--- a/predtuner/approxes/default_approx_params.json
+++ b/predtuner/approxes/default_approx_params.json
@@ -1,5 +1,5 @@
 [{
-    "class": "BaselineKnob",
+    "class": "TorchBaselineKnob",
     "name": "11"
 }, {
     "class": "FP16Approx",
diff --git a/predtuner/modeledapp.py b/predtuner/modeledapp.py
index 6160c98..6c5fe70 100644
--- a/predtuner/modeledapp.py
+++ b/predtuner/modeledapp.py
@@ -9,8 +9,8 @@ import numpy as np
 import pandas as pd
 import torch
 
-from .approxapp import ApproxApp, ApproxTuner, Config, KnobsT
 from ._logging import PathLike
+from .approxapp import ApproxApp, ApproxKnob, ApproxTuner, Config, KnobsT
 
 msg_logger = logging.getLogger(__name__)
 
@@ -23,8 +23,8 @@ class ModeledApp(ApproxApp, abc.ABC):
     for non-modeling application, inherit from `ApproxApp` instead.
     """
 
-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(self, op_knobs: Dict[str, List[ApproxKnob]]) -> None:
+        super().__init__(op_knobs)
         models = self.get_models()
         self._name_to_model = {m.name: m for m in models}
         if len(self._name_to_model) != len(models):
diff --git a/predtuner/pipedbin.py b/predtuner/pipedbin.py
index 91cc33d..113a14b 100644
--- a/predtuner/pipedbin.py
+++ b/predtuner/pipedbin.py
@@ -33,11 +33,12 @@ class PipedBinaryApp(ModeledApp):
             self._metadata = json.load(f)
         (
             self.op_costs,
-            self._op_knobs,
+            op_knobs,
             self.knob_speedups,
             self.baseline_knob,
         ) = self._check_metadata(self._metadata)
-        self._op_order = {v: i for i, v in enumerate(self._op_knobs.keys())}
+        super().__init__(op_knobs)  # Init here
+        self._op_order = {v: i for i, v in enumerate(op_knobs.keys())}
         self.tune_args = self._metadata["tune_args"]
         self.test_args = self._metadata["test_args"]
         self.binary_path = Path(binary_path)
@@ -49,7 +50,6 @@ class PipedBinaryApp(ModeledApp):
         )
         if not self.binary_path.is_file():
             raise RuntimeError(f"Binary file {self.binary_path} not found")
-        super().__init__()
 
         self.process = None
         self.fifo_path = Path(fifo_path)
@@ -63,11 +63,6 @@ class PipedBinaryApp(ModeledApp):
         the user should try to make it unique."""
         return self.app_name
 
-    @property
-    def op_knobs(self) -> Dict[str, List[ApproxKnob]]:
-        """Get a mapping from each operator (identified by str) to a list of applicable knobs."""
-        return self._op_knobs
-
     def measure_qos_perf(
         self, with_approxes: KnobsT, is_test: bool
     ) -> Tuple[float, float]:
diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py
index 3ad4f87..c93bd8a 100644
--- a/predtuner/torchapp.py
+++ b/predtuner/torchapp.py
@@ -8,7 +8,7 @@ from torch.nn import Module
 from torch.utils.data.dataloader import DataLoader
 
 from ._logging import PathLike
-from .approxapp import ApproxKnob, KnobsT
+from .approxapp import ApproxKnob, BaselineKnob, KnobsT
 from .modeledapp import (IPerfModel, IQoSModel, LinearPerfModel, ModeledApp,
                          QoSModelP1, QoSModelP2)
 from .torchutil import ModuleIndexer, get_summary, move_to_device_recursively
@@ -110,31 +110,26 @@ class TorchApp(ModeledApp, abc.ABC):
         self.module = self.module.to(device)
         self.midx = ModuleIndexer(module)
         self._op_costs = {}
-        self._op_knobs = {}
+        op_knobs = {}
         self._knob_speedups = {k.name: k.expected_speedup for k in knobs}
         modules = self.midx.name_to_module
         summary = get_summary(self.module, (self._sample_input(),))
         for op_name, op in modules.items():
-            op_knobs = [
+            this_knobs = [
                 knob for knob in self.name_to_knob.values() if knob.is_applicable(op)
             ]
-            assert op_knobs
-            self._op_knobs[op_name] = op_knobs
+            assert this_knobs
+            op_knobs[op_name] = this_knobs
             self._op_costs[op_name] = summary.loc[op_name, "flops"]
 
         # Init parent class last
-        super().__init__()
+        super().__init__(op_knobs)
 
     @property
     def name(self) -> str:
         """Returns the name of application."""
         return self.app_name
 
-    @property
-    def op_knobs(self) -> Dict[str, List[ApproxKnob]]:
-        """Returns a list of applicable knobs for each operator (layer) in module."""
-        return self._op_knobs
-
     def get_models(self) -> List[Union[IPerfModel, IQoSModel]]:
         """Returns a list of predictive tuning models.
 
@@ -192,7 +187,7 @@ class TorchApp(ModeledApp, abc.ABC):
         module_class_name = type(self.module).__name__
         return (
             f'{class_name}"{self.name}"(module={module_class_name}, '
-            f"num_op={len(self._op_knobs)}, num_knob={len(self.name_to_knob)})"
+            f"num_op={len(self.op_knobs)}, num_knob={len(self.name_to_knob)})"
         )
 
     @torch.no_grad()
@@ -207,18 +202,12 @@ class TorchApp(ModeledApp, abc.ABC):
 
     @staticmethod
     def _check_baseline_knob(knobs: Set[TorchApproxKnob]) -> Set[TorchApproxKnob]:
-        last_baseline_knob = None
-        for k in knobs:
-            if not isinstance(k, BaselineKnob):
-                continue
-            if last_baseline_knob is None:
-                last_baseline_knob = k
-            else:
-                raise ValueError(
-                    f"Found more than 1 baseline knobs: {last_baseline_knob} and {k}"
-                )
-        if last_baseline_knob is None:
-            knobs.add(BaselineKnob())
+        baselines = set(k for k in knobs if isinstance(k, TorchBaselineKnob))
+        if len(baselines) > 1:
+            raise ValueError(f"Found multiple baseline knobs in op_knobs: {baselines}")
+        if not baselines:
+            print("Adding baseline knob to knob set")
+            knobs.add(TorchBaselineKnob())
         return knobs
 
     def _apply_knobs(self, knobs: KnobsT) -> Module:
@@ -235,7 +224,7 @@ class TorchApp(ModeledApp, abc.ABC):
         return inputs.to(self.device)
 
 
-class BaselineKnob(TorchApproxKnob):
+class TorchBaselineKnob(TorchApproxKnob, BaselineKnob):
     def __init__(self, name: str = "__baseline__"):
         super().__init__(name)
 
diff --git a/test/test_torchapp.py b/test/test_torchapp.py
index 0f58e07..74267ef 100644
--- a/test/test_torchapp.py
+++ b/test/test_torchapp.py
@@ -1,7 +1,7 @@
 import unittest
 
 import torch
-from model_zoo import CIFAR, VGG16Cifar10
+from predtuner.model_zoo import CIFAR, VGG16Cifar10
 from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
 from torch.nn import Conv2d, Linear
 from torch.utils.data.dataloader import DataLoader
@@ -42,6 +42,9 @@ class TestTorchAppTuning(TorchAppSetUp):
                 nknob = 1
             self.assertEqual(n_knobs[op_name], nknob)
 
+    def test_baseline_knob(self):
+        self.assertEqual(self.app.baseline_knob.name, "11")
+
     def test_baseline_qos(self):
         qos, _ = self.app.measure_qos_perf({}, False)
         self.assertAlmostEqual(qos, 88.0)
-- 
GitLab