diff --git a/predtuner/approxes/approxes.py b/predtuner/approxes/approxes.py index 6f4c9c361f367e28a1c092e9cd35235021636809..22de3534c31ad34de786a0c912d5946100a0b1bd 100644 --- a/predtuner/approxes/approxes.py +++ b/predtuner/approxes/approxes.py @@ -352,14 +352,11 @@ class FP16Approx(TorchApproxKnob): return True @property - def applicable_op_types(self) -> List[Type[Module]]: - return [Conv2d, Linear] - def expected_speedup(self) -> float: return self.exp_speedup - def is_less_approx(self, other: TorchApproxKnob) -> Optional[bool]: - return None + def is_applicable(self, op: Module) -> bool: + return isinstance(op, (Conv2d, Linear)) class FP16ApproxModule(Module): def __init__(self, module: Module): diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py index e7291edfa34481a6be1bd49e599c035865350997..b30968b8b186dcc4402b8202fa288a971b996db4 100644 --- a/predtuner/torchapp.py +++ b/predtuner/torchapp.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Callable, List, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Set, Tuple, Union import numpy as np import torch @@ -50,6 +50,7 @@ class TorchApp(ModeledApp, abc.ABC): def __init__( self, + app_name: str, module: Module, val_loader: DataLoader, test_loader: DataLoader, @@ -58,7 +59,7 @@ class TorchApp(ModeledApp, abc.ABC): combine_qos: Callable[[np.ndarray], float] = np.mean, device: Union[torch.device, str] = _default_device, ) -> None: - super().__init__() + self.app_name = app_name self.module = module self.val_loader = val_loader self.test_loader = test_loader @@ -67,6 +68,7 @@ class TorchApp(ModeledApp, abc.ABC): self.combine_qos = combine_qos self.device = device + self.module = self.module.to(device) self.midx = ModuleIndexer(module) self._op_costs = {} self._op_knobs = {} @@ -79,6 +81,17 @@ class TorchApp(ModeledApp, abc.ABC): ] self._op_costs[op_name] = summary.loc[op_name, "flops"] + # Init parent class last + super().__init__() + + @property + def name(self) -> str: + return self.app_name + + @property + def op_knobs(self) -> Dict[str, List[ApproxKnob]]: + return self._op_knobs + def get_models(self) -> List[Union[IPerfModel, IQoSModel]]: def batched_valset_qos(tensor_output: torch.Tensor): dataset_len = len(self.val_loader.dataset) diff --git a/predtuner/torchutil/__init__.py b/predtuner/torchutil/__init__.py index a5a0b842d614a4ca830325e48a660aab5fa37765..af018f0d83469755f5092684c19a42aa3089184f 100644 --- a/predtuner/torchutil/__init__.py +++ b/predtuner/torchutil/__init__.py @@ -1,3 +1,4 @@ +from .common_qos import accuracy from .indexing import ModuleIndexer from .summary import get_summary from .utils import (BatchedDataLoader, infer_net_device, diff --git a/predtuner/torchutil/common_qos.py b/predtuner/torchutil/common_qos.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c9ccdbcd2e9917c597e42ec4aae0181086f004 --- /dev/null +++ b/predtuner/torchutil/common_qos.py @@ -0,0 +1,7 @@ +from torch import Tensor + + +def accuracy(output: Tensor, target: Tensor) -> float: + _, pred_labels = output.max(1) + n_correct = (pred_labels == target).sum().item() + return n_correct / len(output) diff --git a/predtuner/torchutil/summary.py b/predtuner/torchutil/summary.py index 25d346f98f19f64e61cda3346b0c467cb1252b84..b6e150db92f09aa7cd93ed363fa9ca325ee49f2e 100644 --- a/predtuner/torchutil/summary.py +++ b/predtuner/torchutil/summary.py @@ -51,12 +51,12 @@ def get_flops(module: nn.Module, input_shape, output_shape): if not handler: if not list(module.children()): _print_once(f"Leaf module {module} cannot be handled") - return None + return 0.0 try: return handler() except RuntimeError as e: _print_once(f'Error "{e}" when handling {module}') - return None + return 0.0 def get_summary(model: nn.Module, model_args: Tuple) -> pandas.DataFrame: diff --git a/test/test_torchapp.py b/test/test_torchapp.py new file mode 100644 index 0000000000000000000000000000000000000000..094c5ddebb14ac49f3788a45ad0a6303b56af056 --- /dev/null +++ b/test/test_torchapp.py @@ -0,0 +1,26 @@ +import unittest + +from predtuner.approxes import get_knobs_from_file +from predtuner.torchapp import TorchApp +from predtuner.torchutil import accuracy +from torch.utils.data.dataloader import DataLoader +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models.vgg import vgg16 + + +class TestTorchAppInit(unittest.TestCase): + def setUp(self): + transform = transforms.Compose([transforms.ToTensor()]) + self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform) + self.module = vgg16(pretrained=True) + + def test_init(self): + app = TorchApp( + "test", + self.module, + DataLoader(self.dataset), + DataLoader(self.dataset), + get_knobs_from_file(), + accuracy, + )