From 007997ee8c5c57ba48fdb806d6b43ec623cf0989 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Fri, 22 Jan 2021 16:35:21 -0600
Subject: [PATCH] Filled in API level 3

---
 predtuner/apps/modeledapp.py    |  19 +++--
 predtuner/apps/torchapp.py      | 134 +++++++++++++++++++++++++++-----
 predtuner/torchutil/__init__.py |   4 +
 predtuner/torchutil/indexing.py | 129 ++++++++++++++++++++++++++++++
 predtuner/torchutil/summary.py  |  77 ++++++++++++++++++
 predtuner/torchutil/utils.py    |  77 ++++++++++++++++++
 6 files changed, 413 insertions(+), 27 deletions(-)
 create mode 100644 predtuner/torchutil/__init__.py
 create mode 100644 predtuner/torchutil/indexing.py
 create mode 100644 predtuner/torchutil/summary.py
 create mode 100644 predtuner/torchutil/utils.py

diff --git a/predtuner/apps/modeledapp.py b/predtuner/apps/modeledapp.py
index 0001b4f..4c01fa5 100644
--- a/predtuner/apps/modeledapp.py
+++ b/predtuner/apps/modeledapp.py
@@ -50,6 +50,9 @@ class ModeledApp(ApproxApp, abc.ABC):
         Empirical measurement will be called once if either `perf_model` or `qos_model`
         is "none", otherwise only use model indicated by model name.
         """
+        # Testset measurement is always empirical
+        if is_testset:
+            return self.empirical_measure_qos_perf(with_approxes, is_testset)
         # Run empirical measurement once if either perf or qos needs it
         qos, perf = None, None
         if qos_model == "none" or perf_model == "none":
@@ -85,7 +88,7 @@ class IPerfModel(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def measure_perf(self, with_approxes: KnobsT, is_testset: bool) -> float:
-        """We implement this using a weighted linear performance model."""
+    def measure_perf(self, with_approxes: KnobsT) -> float:
+        """Predict the performance of the application under the given knobs."""
         pass
 
@@ -100,7 +103,7 @@ class IQoSModel(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def measure_qos(self, with_approxes: KnobsT, is_testset: bool) -> float:
-        """We implement this using a weighted linear performance model."""
+    def measure_qos(self, with_approxes: KnobsT) -> float:
+        """Predict the QoS of the application under the given knobs."""
         pass
 
@@ -126,7 +129,7 @@ class LinearPerfModel(IPerfModel):
     def name(self) -> str:
         return "perf_linear"
 
-    def measure_perf(self, with_approxes: KnobsT, is_testset: bool) -> float:
+    def measure_perf(self, with_approxes: KnobsT) -> float:
         """We implement this using a weighted linear performance model."""
         return sum(
             self.cost_df.loc[layer, knob] for layer, knob in with_approxes.items()
@@ -147,7 +150,7 @@ class QoSModelP1(IQoSModel):
 
     def __init__(
         self,
-        tensor_output_getter: Callable[[KnobsT, bool], torch.Tensor],
+        tensor_output_getter: Callable[[KnobsT], torch.Tensor],
         qos_metric: Callable[[torch.Tensor], float],
     ) -> None:
         super().__init__()
@@ -158,7 +161,7 @@ class QoSModelP1(IQoSModel):
     def name(self) -> str:
         return "qos_p1"
 
-    def measure_qos(self, with_approxes: KnobsT, is_testset: bool) -> float:
-        """Implementation of model."""
+    def measure_qos(self, with_approxes: KnobsT) -> float:
+        """Predict QoS from the application's tensor output (the P1 model)."""
         pass
 
@@ -174,16 +177,16 @@ class QoSModelP2(IQoSModel):
     def name(self) -> str:
         return "qos_p2"
 
-    def _empirical_measure_qos(self, with_approxes: KnobsT, is_testset: bool) -> float:
+    def _empirical_measure_qos(self, with_approxes: KnobsT) -> float:
         """An internal QoS-measuring method.
 
         The point is P2 queries some QoS results and caches them before tuning starts,
         and then defines a `measure_qos` that doesn't run the application during tuning
         (to reduce overhead).
         """
-        qos, _ = self.app.empirical_measure_qos_perf(with_approxes, is_testset)
+        qos, _ = self.app.empirical_measure_qos_perf(with_approxes, is_testset=False)
         return qos
 
-    def measure_qos(self, with_approxes: KnobsT, is_testset: bool) -> float:
-        """Implementation of model."""
+    def measure_qos(self, with_approxes: KnobsT) -> float:
+        """Predict QoS from QoS results cached before tuning (the P2 model)."""
         pass
diff --git a/predtuner/apps/torchapp.py b/predtuner/apps/torchapp.py
index da2f3c6..261e4b6 100644
--- a/predtuner/apps/torchapp.py
+++ b/predtuner/apps/torchapp.py
@@ -1,10 +1,15 @@
 import abc
+import copy
-from typing import Set
+from typing import Any, Callable, List, Set, Tuple, Union
 
+import numpy as np
+import torch
+from torch.nn import Module
 from torch.utils.data.dataloader import DataLoader
 
-from .approxapp import ApproxKnob
-from .modeledapp import IPerfModeled, IQoSModeledP1, IQoSModeledP2, ModeledApp
+from ..torchutil import ModuleIndexer, get_summary, move_to_device_recursively
+from .approxapp import ApproxKnob, KnobsT
+from .modeledapp import (IPerfModel, IQoSModel, LinearPerfModel, ModeledApp,
+                         QoSModelP1, QoSModelP2)
 
 
 class TorchApproxKnob(ApproxKnob):
@@ -12,28 +17,119 @@ class TorchApproxKnob(ApproxKnob):
     its own expected speedup ratio and what Modules it can apply to,
     and can be applied to a torch.nn.Module to return an approximated Module."""
 
-    pass
-
-
-class TorchApp(ModeledApp, IPerfModeled, IQoSModeledP1, IQoSModeledP2, abc.ABC):
-    """Approximable PyTorch Modules (tensor output assumed).
-  
-    Automatically derives performance model and QoS models P1&P2."""
+    @property
+    @abc.abstractmethod
+    def deterministic(self) -> bool:
+        """Returns true if approx knob does not contain randomness."""
+        pass
 
     @property
     @abc.abstractmethod
-    def all_knobs(self) -> Set[TorchApproxKnob]:
-        """User defines a set of all knobs available; we'll dispatch them to each layer (op)."""
+    def expected_speedup(self) -> float:
+        """The speedup ratio this knob is expected to provide."""
         pass
 
     @abc.abstractmethod
-    def get_input_data(self, testset: bool) -> DataLoader:
-        """User defines the input dataset to traverse."""
+    def is_applicable(self, op: Module) -> bool:
+        """Returns True if this knob can be applied to the Module `op`."""
         pass
 
-    # User also needs to define `IQoSModeledP1.qos_from_output` (QoS metric, omitted)
+    @abc.abstractmethod
+    def apply(self, op: Module) -> Module:
+        """Applies knob to `module` and returns an approximated `module`."""
+        pass
+
+
+_default_device = f"cuda" if torch.cuda.is_available() else "cpu"
+
+
+class TorchApp(ModeledApp, abc.ABC):
+    """Approximable PyTorch Modules (tensor output assumed).
+    Automatically derives performance model and QoS models P1&P2.
+    
+    knobs: User defines a set of all knobs available; we'll dispatch them to each layer (op).
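+
+    A minimal usage sketch (the net, loaders, and knob below are hypothetical):
+
+    >>> app = TorchApp(
+    ...     module=my_net, val_loader=val_loader, test_loader=test_loader,
+    ...     knobs={MyKnob()},
+    ...     tensor_to_qos=lambda out, tgt: (out.argmax(1) == tgt).float().mean().item(),
+    ... )
+    >>> qos, perf = app.empirical_measure_qos_perf({}, is_testset=False)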
+    """
+
+    def __init__(
+        self,
+        module: Module,
+        val_loader: DataLoader,
+        test_loader: DataLoader,
+        knobs: Set[TorchApproxKnob],
+        tensor_to_qos: Callable[[torch.Tensor, Any], float],
+        combine_qos: Callable[[np.ndarray], float] = np.mean,
+        device: Union[torch.device, str] = _default_device,
+    ) -> None:
+        super().__init__()
+        # Move the module to the target device once, so the summary pass and
+        # all measurements run on the same device as the inputs.
+        self.module = module.to(device)
+        self.val_loader = val_loader
+        self.test_loader = test_loader
+        self.name_to_knob = {k.name: k for k in knobs}
+        self.tensor_to_qos = tensor_to_qos
+        self.combine_qos = combine_qos
+        self.device = device
+
+        self.midx = ModuleIndexer(module)
+        self._op_costs = {}
+        self._op_knobs = {}
+        self._knob_speedups = {k.name: k.expected_speedup for k in knobs}
+        modules = self.midx.name_to_module
+        summary = get_summary(self.module, (self._sample_input(),))
+        for op_name, op in modules.items():
+            self._op_knobs[op_name] = [
+                knob for knob in self.name_to_knob.values() if knob.is_applicable(op)
+            ]
+            self._op_costs[op_name] = summary.loc[op_name, "flops"]
+
+    def get_models(self) -> List[Union[IPerfModel, IQoSModel]]:
+        def batched_valset_qos(tensor_output: torch.Tensor):
+            # Re-slice the concatenated tensor output along the validation
+            # loader's batch boundaries and compute QoS per batch.
+            dataset_len = len(self.val_loader.dataset)
+            assert len(tensor_output) == dataset_len
+            begin = 0
+            qoses = []
+            for _, target in self.val_loader:
+                end = begin + len(target)
+                qoses.append(self.tensor_to_qos(tensor_output[begin:end], target))
+                begin = end  # advance to the next batch
+            return self.combine_qos(np.array(qoses))
+
+        return [
+            LinearPerfModel(self._op_costs, self._knob_speedups),
+            QoSModelP1(self._get_raw_output_valset, batched_valset_qos),
+            QoSModelP2(self),
+        ]
+
+    @torch.no_grad()
+    def empirical_measure_qos_perf(
+        self, with_approxes: KnobsT, is_testset: bool
+    ) -> Tuple[float, float]:
+        dataloader = self.test_loader if is_testset else self.val_loader
+        approxed = self._apply_knobs(with_approxes)
+        qoses = []
+        for inputs, targets in dataloader:
+            inputs = move_to_device_recursively(inputs, self.device)
+            # Keep targets on the same device as the outputs so user-supplied
+            # QoS metrics don't see a device mismatch.
+            targets = move_to_device_recursively(targets, self.device)
+            outputs = approxed(inputs)
+            qoses.append(self.tensor_to_qos(outputs, targets))
+        qos = self.combine_qos(np.array(qoses))
+        # Return order is (qos, perf); empirical performance measurement is
+        # not implemented yet, so 0.0 stands in for perf.
+        return qos, 0.0
+
+    @torch.no_grad()
+    def _get_raw_output_valset(self, with_approxes: KnobsT):
+        approxed = self._apply_knobs(with_approxes)
+        all_outputs = []
+        for inputs, _ in self.val_loader:
+            inputs = move_to_device_recursively(inputs, self.device)
+            # Gather outputs on CPU so large validation sets don't hold GPU memory
+            all_outputs.append(approxed(inputs).cpu())
+        # Concatenate along the batch dimension; `stack` would insert an extra
+        # per-batch dimension and break the length check in the QoS callback.
+        return torch.cat(all_outputs, dim=0)
+
+    def _apply_knobs(self, knobs: KnobsT) -> Module:
+        # Deep-copy the module (via its indexer) so the baseline stays intact,
+        # then swap each chosen layer for its approximated version.
+        module_indexer = copy.deepcopy(self.midx)
+        for op_name, knob_name in knobs.items():
+            knob = self.name_to_knob[knob_name]
+            module_indexer[op_name] = knob.apply(module_indexer[op_name])
+        return module_indexer.module
 
-    # We implement `ApproxApp.op_knobs`,
-    # `IPerfModeled.op_knobs_cost`,
-    # `IQoSModeledP1.get_tensor_output`
-    # and `IQoSModeledP2._measure_qos`. (Omitted)
+    def _sample_input(self):
+        # One batch from the validation set, used to trace the model summary
+        inputs, _ = next(iter(self.val_loader))
+        return move_to_device_recursively(inputs, self.device)
diff --git a/predtuner/torchutil/__init__.py b/predtuner/torchutil/__init__.py
new file mode 100644
index 0000000..a5a0b84
--- /dev/null
+++ b/predtuner/torchutil/__init__.py
@@ -0,0 +1,4 @@
+from .indexing import ModuleIndexer
+from .summary import get_summary
+from .utils import (BatchedDataLoader, infer_net_device,
+                    move_to_device_recursively, split_dataset)
diff --git a/predtuner/torchutil/indexing.py b/predtuner/torchutil/indexing.py
new file mode 100644
index 0000000..a059ad5
--- /dev/null
+++ b/predtuner/torchutil/indexing.py
@@ -0,0 +1,129 @@
+"""Tools for indexing into an nn.Module with layer name (str) or index (int)."""
+from typing import Callable, Dict, Iterator, Optional, Set, Tuple, Union
+
+from torch.nn import Module, Sequential
+
+ModulePredT = Callable[[Module], bool]
+
+
+class ModuleIndexer:
+    r"""Allows indexing into an nn.Module with index (int) to get layers.
+        Supports read and modification, just like a dictionary.
+
+        Parameters
+        ----------
+        module: Module
+            The PyTorch Module to be indexed.
+        include_module: Callable[[Module], bool] = None
+            A predicate that decides which layers to include in the index. For example,
+            `lambda layer: isinstance(layer, Conv2d)` tells `ModuleIndexer` to only include `Conv2d`
+            layers.
+            If not given, by default `ModuleIndexer` walks down `module` like a tree
+            and includes only leaf layers (layers with no children) that are not
+            `Sequential` containers.
+        expand_module: Callable[[Module], bool] = None
+            A predicate that decides whether to recurse into an *included* layer.
+            Layers that are not included are always recursed into; included layers
+            are recursed into only if `expand_module` returns `True` for them.
+
+        Attributes
+        ----------
+        module: Module
+            Equal to parameter `module`.
+        index_to_module: List[Module]
+            Stores the layers in order so that a layer at `index_to_module[i]` has the index `i`.
+        layer_parent: Dict[Module, Tuple[Module, str]]
+            Maps each layer to its parent and its name in the parent layer. Contains the same layers
+            as in `index_to_module` except `module` which has no parent.
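+
+        Example
+        -------
+        A minimal sketch with a toy two-layer net:
+
+        >>> from torch.nn import Conv2d, Identity, ReLU, Sequential
+        >>> net = Sequential(Conv2d(3, 8, 3), ReLU())
+        >>> idx = ModuleIndexer(net)
+        >>> idx[0] is idx["0"]  # index by position or by layer name
+        True
+        >>> idx["1"] = Identity()  # swaps the ReLU inside `net` as well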
+        """
+
+    def __init__(
+            self, module: Module, include_module: Optional[ModulePredT] = None,
+            expand_module: Optional[ModulePredT] = None
+    ):
+        self.module = module
+        self.index_to_module = []
+        self.module_to_name = {}
+        self.name_to_index = {}
+        # By default, include only leaf layers (no children) and skip Sequential containers
+        has_children = lambda m: bool(list(m.children()))
+        default_inclusion = lambda m: not has_children(m) and not isinstance(m, Sequential)
+        # No need for "default expansion" because whatever is not included will be walked into.
+        self._rec_expand_module(
+            module, '', include_module or default_inclusion, expand_module
+        )
+        self.layer_parent = self._find_layers_parent_info(module, set(self.all_modules))
+
+    def _rec_expand_module(
+            self, module: Module, name_prefix: str,
+            include_module: ModulePredT, expand_module: Optional[ModulePredT]
+    ):
+        """Recursively expands into module and builds the index."""
+        for name, submodule in module.named_children():
+            full_name = f"{name_prefix}.{name}" if name_prefix else name
+            included = include_module(submodule)
+            if included:
+                self.index_to_module.append(submodule)
+                self.module_to_name[submodule] = full_name
+                self.name_to_index[full_name] = len(self.index_to_module) - 1
+            required_expansion = expand_module and expand_module(submodule)
+            default_expansion = not included
+            if default_expansion or required_expansion:
+                self._rec_expand_module(submodule, full_name, include_module, expand_module)
+
+    @staticmethod
+    def _find_layers_parent_info(net: Module, layers: Set[Module]):
+        """Find parent info for each child layer in `net`, ignoring those not in `layers`."""
+        ret = {}
+        for name, submodule in net.named_children():
+            if submodule in layers:
+                ret[submodule] = net, name
+            ret = {**ret, **ModuleIndexer._find_layers_parent_info(submodule, layers)}
+        return ret
+
+    @property
+    def all_modules(self) -> Iterator[Module]:
+        return iter(self.index_to_module)
+
+    @property
+    def name_to_module(self) -> Dict[str, Module]:
+        return {name: self.index_to_module[index] for name, index in self.name_to_index.items()}
+
+    def find_by_module(self, module: Module) -> Optional[Tuple[str, int]]:
+        """Get name and index from module."""
+        name = self.module_to_name.get(module, None)
+        if name is None:
+            return None
+        index = self.name_to_index[name]
+        return name, index
+
+    def __getitem__(self, item: Union[int, str]) -> Module:
+        """Get module from index."""
+        if isinstance(item, int):
+            return self.index_to_module[item]
+        elif isinstance(item, str):
+            return self[self.name_to_index[item]]
+        raise KeyError(f"Key type {item.__class__} not understood")
+
+    def __setitem__(self, key: Union[int, str], value: Module):
+        """Swap in the layer at index `key` to be `value`.
+
+        The parent of the old layer at `key` is also updated with the new layer, so that `self.module`
+        has the old layer replaced with new.
+        """
+        if isinstance(key, str):
+            key = self.name_to_index[key]
+        old = self.index_to_module[key]
+        if value is not old:
+            self.index_to_module[key] = value
+            self.module_to_name[value] = self.module_to_name.pop(old)
+            parent, name = self.layer_parent[old]
+            self.layer_parent[value] = parent, name
+            self.layer_parent.pop(old)
+            setattr(parent, name, value)
+
+    def __iter__(self) -> Iterator[Module]:
+        return self.all_modules
+
+    def __len__(self):
+        """Number of indexed layers."""
+        return len(self.index_to_module)
diff --git a/predtuner/torchutil/summary.py b/predtuner/torchutil/summary.py
new file mode 100644
index 0000000..7491fa3
--- /dev/null
+++ b/predtuner/torchutil/summary.py
@@ -0,0 +1,77 @@
+from collections import OrderedDict
+from typing import Tuple
+
+import pandas
+import torch
+import torch.nn as nn
+
+from .indexing import ModuleIndexer
+
+
+def get_flops(module: nn.Module, input_shape, output_shape):
+    """Estimate per-sample FLOPs of a layer; returns None if unknown."""
+    if output_shape is None:
+        return None
+    if isinstance(module, nn.Linear):
+        if input_shape is None:
+            return None
+        _, n = input_shape
+        k, n_ = module.weight.shape  # weight: (out_features, in_features)
+        assert n == n_
+        return n * k  # in_features * out_features MACs per sample
+    if isinstance(module, nn.Conv2d):
+        _, _, h, w = output_shape
+        # (out_c * in_c * kh * kw) MACs per output pixel
+        return module.weight.numel() * h * w
+    if isinstance(module, nn.BatchNorm2d):
+        # Roughly 6 elementwise ops per output element (normalize, scale, shift)
+        _, c, h, w = output_shape
+        return 6 * c * h * w
+    return None
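+
+# Example (illustrative): for a Conv2d with weight shape (8, 3, 3, 3) and a
+# 32x32 output feature map, get_flops returns 8*3*3*3 * 32*32 = 221,184
+# MACs per sample.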
+
+
+def get_summary(model: nn.Module, model_args: Tuple) -> pandas.DataFrame:
+    """Run `model` once on `model_args` and tabulate per-layer shapes,
+    parameter counts, and estimated FLOPs (one row per layer)."""
+    include = lambda m: (
+        not isinstance(m, (nn.Sequential, nn.ModuleList)) and m is not model
+    )
+    indexed = ModuleIndexer(model, include, lambda m: True)
+    find_by_module = lambda m: indexed.find_by_module(m)[0]
+    summary = OrderedDict()
+    hooks = []
+
+    def hook(module: nn.Module, inputs, outputs):
+        module_name = find_by_module(module)
+
+        try:
+            input_shape = list(inputs[0].size())
+        except AttributeError:
+            input_shape = None
+        try:
+            if isinstance(outputs, (list, tuple)):
+                output_shape = [[-1] + list(o.size())[1:] for o in outputs]
+            else:
+                output_shape = list(outputs.size())
+        except AttributeError:
+            output_shape = None
+
+        n_params = sum(param.numel() for param in module.parameters())
+        trainable = any(param.requires_grad for param in module.parameters())
+
+        summary[module_name] = OrderedDict(
+            type=module.__class__.__name__,
+            input_shape=input_shape,
+            output_shape=output_shape,
+            params=n_params,
+            flops=get_flops(module, input_shape, output_shape),
+            trainable=trainable
+        )
+
+    def register_hook(module: nn.Module):
+        if include(module):
+            hooks.append(module.register_forward_hook(hook))
+
+    # register hook
+    model.apply(register_hook)
+    with torch.no_grad():
+        model(*model_args)
+    # remove these hooks
+    for h in hooks:
+        h.remove()
+    # Transpose so rows are layer names and columns are the fields
+    # (type, shapes, params, flops, trainable); callers index like
+    # `df.loc[layer_name, "flops"]`.
+    return pandas.DataFrame(summary).T
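+
+
+# Example (a sketch with a toy model; layer "0" is the Conv2d):
+#
+#     net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
+#     df = get_summary(net, (torch.randn(1, 3, 32, 32),))
+#     conv_flops = df.loc["0", "flops"]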
diff --git a/predtuner/torchutil/utils.py b/predtuner/torchutil/utils.py
new file mode 100644
index 0000000..4fb6564
--- /dev/null
+++ b/predtuner/torchutil/utils.py
@@ -0,0 +1,77 @@
+from typing import Optional, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+from torch.utils.data import DataLoader, Dataset, Subset
+from torch.utils.data._utils.fetch import _BaseDatasetFetcher
+from torch.utils.data.dataloader import _SingleProcessDataLoaderIter
+
+
+def infer_net_device(net: Module):
+    """Guess the device `net` is on.
+
+    This assumes its all parts are on the same device, and takes the device of any parameter.
+    This function does not check the device of buffers, etc. in `net`."""
+    devices = set(pm.device for pm in net.parameters())
+    if len(devices) == 0:
+        raise RuntimeError("Cannot infer device for net with no parameters")
+    if len(devices) > 1:
+        raise RuntimeError("Parts of the network are on different devices")
+    (device,) = devices
+    return device
+
+
+def move_to_device_recursively(data: object, device: Union[torch.device, str]):
+    """Move all Tensors in `data` recursively to `device`."""
+    if isinstance(data, Tensor):
+        return data.to(device)
+    if isinstance(data, dict):
+        return {k: move_to_device_recursively(v, device) for k, v in data.items()}
+    if not hasattr(data, "__dict__"):
+        if isinstance(data, list):
+            return [move_to_device_recursively(x, device) for x in data]
+        elif isinstance(data, tuple):
+            return tuple(move_to_device_recursively(x, device) for x in data)
+        else:
+            raise RuntimeError(f"Don't know how to manipulate {type(data)}")
+    for key, value in data.__dict__.items():
+        data.__dict__[key] = move_to_device_recursively(value, device)
+    return data
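+
+
+# Example (a sketch): move a (possibly nested) batch onto the device a
+# network lives on.
+#
+#     device = infer_net_device(net)
+#     inputs = move_to_device_recursively(inputs, device)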
+
+
+def split_dataset(dataset: Dataset, split_at: int):
+    return (
+        Subset(dataset, torch.arange(0, split_at)),
+        Subset(dataset, torch.arange(split_at, len(dataset))),
+    )
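+
+
+# Example (a sketch): keep the first 5000 samples for validation and the
+# rest for testing.
+#
+#     val_set, test_set = split_dataset(dataset, split_at=5000)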
+
+
+class BatchedDataLoader(DataLoader):
+    """Faster data loader for datasets that supports batch indexing.
+
+    Some datasets load the whole Tensor into memory and can be indexed by a batch of indices,
+    instead of indexed one by one and stacking the data together (which is what DataLoader does).
+    `BatchedDataLoader` instead uses `_BatchedMapDatasetFetcher` to batch index the dataset,
+    removing some overhead.
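+
+    Example (a sketch; `TensorDataset` supports batched indexing):
+
+    >>> from torch.utils.data import TensorDataset
+    >>> xs, ys = torch.randn(100, 4), torch.zeros(100)
+    >>> loader = BatchedDataLoader(TensorDataset(xs, ys), batch_size=32)
+    >>> next(iter(loader))[0].shape
+    torch.Size([32, 4])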
+    """
+
+    def __init__(self, dataset: Dataset, batch_size: Optional[int], *args, **kwargs):
+        super().__init__(dataset, batch_size=batch_size, *args, **kwargs)
+        # `__iter__` reads `support_batch`, so set it optimistically *before*
+        # probing one batch; fall back to element-wise fetching on failure.
+        self.support_batch = True
+        try:
+            next(iter(self))
+        except (IndexError, KeyError, TypeError, ValueError, RuntimeError):
+            self.support_batch = False
+
+    def __iter__(self):
+        if self.num_workers == 0 and self.support_batch:
+            dl_iter = _SingleProcessDataLoaderIter(self)
+            dl_iter._dataset_fetcher = _BatchedMapDatasetFetcher(
+                self.dataset, self._auto_collation, self.collate_fn, self.drop_last
+            )
+            return dl_iter
+        return super(BatchedDataLoader, self).__iter__()
+
+
+class _BatchedMapDatasetFetcher(_BaseDatasetFetcher):
+    def fetch(self, possibly_batched_index):
+        return self.dataset[possibly_batched_index]
-- 
GitLab