Skip to content
Snippets Groups Projects
Commit 92a20e60 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

TorchApp passed initialization test

parent ceaca2c9
No related branches found
No related tags found
No related merge requests found
......@@ -352,14 +352,11 @@ class FP16Approx(TorchApproxKnob):
return True
@property
def applicable_op_types(self) -> List[Type[Module]]:
return [Conv2d, Linear]
def expected_speedup(self) -> float:
return self.exp_speedup
def is_less_approx(self, other: TorchApproxKnob) -> Optional[bool]:
return None
def is_applicable(self, op: Module) -> bool:
return isinstance(op, (Conv2d, Linear))
class FP16ApproxModule(Module):
def __init__(self, module: Module):
......
import abc
from typing import Any, Callable, List, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Set, Tuple, Union
import numpy as np
import torch
......@@ -50,6 +50,7 @@ class TorchApp(ModeledApp, abc.ABC):
def __init__(
self,
app_name: str,
module: Module,
val_loader: DataLoader,
test_loader: DataLoader,
......@@ -58,7 +59,7 @@ class TorchApp(ModeledApp, abc.ABC):
combine_qos: Callable[[np.ndarray], float] = np.mean,
device: Union[torch.device, str] = _default_device,
) -> None:
super().__init__()
self.app_name = app_name
self.module = module
self.val_loader = val_loader
self.test_loader = test_loader
......@@ -67,6 +68,7 @@ class TorchApp(ModeledApp, abc.ABC):
self.combine_qos = combine_qos
self.device = device
self.module = self.module.to(device)
self.midx = ModuleIndexer(module)
self._op_costs = {}
self._op_knobs = {}
......@@ -79,6 +81,17 @@ class TorchApp(ModeledApp, abc.ABC):
]
self._op_costs[op_name] = summary.loc[op_name, "flops"]
# Init parent class last
super().__init__()
@property
def name(self) -> str:
return self.app_name
@property
def op_knobs(self) -> Dict[str, List[ApproxKnob]]:
return self._op_knobs
def get_models(self) -> List[Union[IPerfModel, IQoSModel]]:
def batched_valset_qos(tensor_output: torch.Tensor):
dataset_len = len(self.val_loader.dataset)
......
from .common_qos import accuracy
from .indexing import ModuleIndexer
from .summary import get_summary
from .utils import (BatchedDataLoader, infer_net_device,
......
from torch import Tensor
def accuracy(output: Tensor, target: Tensor) -> float:
_, pred_labels = output.max(1)
n_correct = (pred_labels == target).sum().item()
return n_correct / len(output)
......@@ -51,12 +51,12 @@ def get_flops(module: nn.Module, input_shape, output_shape):
if not handler:
if not list(module.children()):
_print_once(f"Leaf module {module} cannot be handled")
return None
return 0.0
try:
return handler()
except RuntimeError as e:
_print_once(f'Error "{e}" when handling {module}')
return None
return 0.0
def get_summary(model: nn.Module, model_args: Tuple) -> pandas.DataFrame:
......
import unittest
from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16
class TestTorchAppInit(unittest.TestCase):
def setUp(self):
transform = transforms.Compose([transforms.ToTensor()])
self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
self.module = vgg16(pretrained=True)
def test_init(self):
app = TorchApp(
"test",
self.module,
DataLoader(self.dataset),
DataLoader(self.dataset),
get_knobs_from_file(),
accuracy,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment