Skip to content
Snippets Groups Projects
Commit 92a20e60 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

TorchApp passed initialization test

parent ceaca2c9
No related branches found
No related tags found
No related merge requests found
...@@ -352,14 +352,11 @@ class FP16Approx(TorchApproxKnob): ...@@ -352,14 +352,11 @@ class FP16Approx(TorchApproxKnob):
return True return True
@property @property
def applicable_op_types(self) -> List[Type[Module]]:
return [Conv2d, Linear]
def expected_speedup(self) -> float: def expected_speedup(self) -> float:
return self.exp_speedup return self.exp_speedup
def is_less_approx(self, other: TorchApproxKnob) -> Optional[bool]: def is_applicable(self, op: Module) -> bool:
return None return isinstance(op, (Conv2d, Linear))
class FP16ApproxModule(Module): class FP16ApproxModule(Module):
def __init__(self, module: Module): def __init__(self, module: Module):
......
import abc import abc
from typing import Any, Callable, List, Set, Tuple, Union from typing import Any, Callable, Dict, List, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -50,6 +50,7 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -50,6 +50,7 @@ class TorchApp(ModeledApp, abc.ABC):
def __init__( def __init__(
self, self,
app_name: str,
module: Module, module: Module,
val_loader: DataLoader, val_loader: DataLoader,
test_loader: DataLoader, test_loader: DataLoader,
...@@ -58,7 +59,7 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -58,7 +59,7 @@ class TorchApp(ModeledApp, abc.ABC):
combine_qos: Callable[[np.ndarray], float] = np.mean, combine_qos: Callable[[np.ndarray], float] = np.mean,
device: Union[torch.device, str] = _default_device, device: Union[torch.device, str] = _default_device,
) -> None: ) -> None:
super().__init__() self.app_name = app_name
self.module = module self.module = module
self.val_loader = val_loader self.val_loader = val_loader
self.test_loader = test_loader self.test_loader = test_loader
...@@ -67,6 +68,7 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -67,6 +68,7 @@ class TorchApp(ModeledApp, abc.ABC):
self.combine_qos = combine_qos self.combine_qos = combine_qos
self.device = device self.device = device
self.module = self.module.to(device)
self.midx = ModuleIndexer(module) self.midx = ModuleIndexer(module)
self._op_costs = {} self._op_costs = {}
self._op_knobs = {} self._op_knobs = {}
...@@ -79,6 +81,17 @@ class TorchApp(ModeledApp, abc.ABC): ...@@ -79,6 +81,17 @@ class TorchApp(ModeledApp, abc.ABC):
] ]
self._op_costs[op_name] = summary.loc[op_name, "flops"] self._op_costs[op_name] = summary.loc[op_name, "flops"]
# Init parent class last
super().__init__()
@property
def name(self) -> str:
return self.app_name
@property
def op_knobs(self) -> Dict[str, List[ApproxKnob]]:
return self._op_knobs
def get_models(self) -> List[Union[IPerfModel, IQoSModel]]: def get_models(self) -> List[Union[IPerfModel, IQoSModel]]:
def batched_valset_qos(tensor_output: torch.Tensor): def batched_valset_qos(tensor_output: torch.Tensor):
dataset_len = len(self.val_loader.dataset) dataset_len = len(self.val_loader.dataset)
......
from .common_qos import accuracy
from .indexing import ModuleIndexer from .indexing import ModuleIndexer
from .summary import get_summary from .summary import get_summary
from .utils import (BatchedDataLoader, infer_net_device, from .utils import (BatchedDataLoader, infer_net_device,
......
from torch import Tensor
def accuracy(output: Tensor, target: Tensor) -> float:
_, pred_labels = output.max(1)
n_correct = (pred_labels == target).sum().item()
return n_correct / len(output)
...@@ -51,12 +51,12 @@ def get_flops(module: nn.Module, input_shape, output_shape): ...@@ -51,12 +51,12 @@ def get_flops(module: nn.Module, input_shape, output_shape):
if not handler: if not handler:
if not list(module.children()): if not list(module.children()):
_print_once(f"Leaf module {module} cannot be handled") _print_once(f"Leaf module {module} cannot be handled")
return None return 0.0
try: try:
return handler() return handler()
except RuntimeError as e: except RuntimeError as e:
_print_once(f'Error "{e}" when handling {module}') _print_once(f'Error "{e}" when handling {module}')
return None return 0.0
def get_summary(model: nn.Module, model_args: Tuple) -> pandas.DataFrame: def get_summary(model: nn.Module, model_args: Tuple) -> pandas.DataFrame:
......
import unittest
from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16
class TestTorchAppInit(unittest.TestCase):
def setUp(self):
transform = transforms.Compose([transforms.ToTensor()])
self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
self.module = vgg16(pretrained=True)
def test_init(self):
app = TorchApp(
"test",
self.module,
DataLoader(self.dataset),
DataLoader(self.dataset),
get_knobs_from_file(),
accuracy,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment