Commit db54bdd8 authored by Yifan Zhao

Fixed corner cases and added test for generating all 10 benchmarks

parent 0547093f
@@ -221,9 +221,13 @@ class DFG(object):
             "Identity": g.IdentityNode,
             "Flatten": g.FlattenNode,
         }
-        if node.op_type in one_to_one_nodes:
-            return one_to_one_nodes[node.op_type](node.name, **attrs)
-        return None
+        if node.op_type not in one_to_one_nodes:
+            return None
+        try:
+            return one_to_one_nodes[node.op_type](node.name, **attrs)
+        except (TypeError, KeyError, ValueError, RuntimeError):
+            node_class = one_to_one_nodes[node.op_type]
+            raise ValueError(f"Node ({node_class}) creation failed")

 def def_use(nodes: Iterable) -> Tuple[dict, dict]:
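A quick sketch of what the reworked lookup means for callers (illustration only, not part of the commit; `emit_one_to_one` and `emit_special_case` are hypothetical names standing in for the surrounding graph-builder code):

    node = emit_one_to_one(onnx_node)       # the method changed in the hunk above
    if node is None:
        # op_type has no entry in one_to_one_nodes: fall back to other handlers
        node = emit_special_case(onnx_node)
    # If the op_type *is* in the table but its constructor rejects the attributes,
    # the new code raises ValueError("Node (<class>) creation failed") instead of
    # leaking the original TypeError/KeyError/ValueError/RuntimeError.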
@@ -134,18 +134,23 @@ class Conv2DNode(DFGNode):
             raise ValueError("Convolution with different padding is unsupported")
         if list(dilations) != [1, 1]:
             raise ValueError("Dilation > 1 is unsupported")
-        if group != 1:
-            raise ValueError("Group > 1 is unsupported")
+        self.group = group
         self.pads = pads[0]
         self.sh, self.sw = strides

     def codegen(self):
         return (
             "tensorConvolution",
-            [self.pads, self.pads, self.sh, self.sw],
+            [self.pads, self.pads, self.sh, self.sw, 1, self.group],
         )

     def hpvm_codegen(self):
+        if self.group != 1:
+            return (
+                "__hpvm__tensor_group_convolution",
+                # 1 is conv_mode -- should always be 1
+                [self.pads, self.pads, self.sh, self.sw, 1, self.group],
+            )
         return (
             "__hpvm__tensor_convolution",
             [self.pads, self.pads, self.sh, self.sw],
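A hand-worked example of the new return values (numbers chosen for illustration, not taken from the commit): for a 3x3 depthwise convolution with padding 1, stride 1, and group equal to the input channel count C, the two backends now see

    self.codegen()       # ("tensorConvolution", [1, 1, 1, 1, 1, C])
    self.hpvm_codegen()  # ("__hpvm__tensor_group_convolution", [1, 1, 1, 1, 1, C])

where the argument order is pads, pads, sh, sw, then the constant 1 (presumably the conv_mode flag named in the comment above) and group. A regular group == 1 convolution still lowers to "__hpvm__tensor_convolution" with only the four padding/stride arguments.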
@@ -323,7 +328,8 @@ class BatchNormalizationNode(DFGNode):
         input_shapes: Tuple[ShapeT, ShapeT],
         output_shape: ShapeT,
         epsilon: float,
-        axis: int,
+        axis: int = None,
+        momentum: float = None,
     ):
         super().__init__(name, input_shapes, output_shape)
         self.epsilon = epsilon
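With the new keyword defaults, a BatchNormalizationNode can be constructed even when the incoming attribute set omits axis or carries a momentum value. A sketch with made-up values (input_shapes and output_shape stand for the parameters shown in the signature above):

    node = BatchNormalizationNode(
        "bn_1", input_shapes, output_shape, epsilon=1e-5, momentum=0.9
    )  # axis falls back to None; momentum no longer raises TypeError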
 from .alexnet import AlexNet, AlexNet2, AlexNetImageNet
 from .datasets import CIFAR, MNIST, ImageNet
 from .lenet import LeNet
-from .vgg16 import VGG16Cifar10, VGG16Cifar100
+from .vgg16 import VGG16Cifar10, VGG16Cifar100, VGG16ImageNet
 from .mobilenet import MobileNet
 from .resnet import ResNet18, ResNet50
@@ -39,3 +39,8 @@ class VGG16Cifar10(_VGG16):
 class VGG16Cifar100(_VGG16):
     def __init__(self):
         super().__init__([512, 512, 100])
+
+
+class VGG16ImageNet(_VGG16):
+    def __init__(self):
+        super().__init__([25088, 4096, 4096, 1000])
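The new layer widths are consistent with standard VGG-16 on 224x224 ImageNet inputs: the convolutional stack ends with 512 channels at 7x7 spatial resolution, which flattens into the first fully-connected layer, and 1000 is the ImageNet class count.

    assert 512 * 7 * 7 == 25088  # flattened feature size after the conv stack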
+import os
+import site
 from pathlib import Path
-from torch2hpvm import ModelExporter, BinDataset
-from dnn import AlexNet
+import shutil
-prefix = Path(__file__).parent / "../model_params/alexnet_cifar10"
-dataset_shape = 5000, 3, 32, 32
-bin_tuneset = BinDataset(prefix / "tune_input.bin", prefix / "tune_labels.bin", dataset_shape)
-bin_testset = BinDataset(prefix / "test_input.bin", prefix / "test_labels.bin", dataset_shape)
-ModelExporter(AlexNet(), bin_tuneset, bin_testset, "/tmp/alexnet", True).export_all()
+from torch2hpvm import BinDataset, ModelExporter
+
+site.addsitedir(os.path.dirname(__file__))
+import dnn
+
+benchmarks = [
+    (dnn.LeNet, 1, 28, "lenet_mnist"),
+    (dnn.AlexNet, 3, 32, "alexnet_cifar10"),
+    (dnn.AlexNet2, 3, 32, "alexnet2_cifar10"),
+    (dnn.AlexNetImageNet, 3, 224, "alexnet_imagenet"),
+    (dnn.MobileNet, 3, 32, "mobilenet_cifar10"),
+    (dnn.ResNet18, 3, 32, "resnet18_cifar10"),
+    (dnn.ResNet50, 3, 224, "resnet50_imagenet"),
+    (dnn.VGG16Cifar10, 3, 32, "vgg16_cifar10"),
+    (dnn.VGG16Cifar100, 3, 32, "vgg16_cifar100"),
+    (dnn.VGG16ImageNet, 3, 224, "vgg16_imagenet"),
+]
+self_folder = Path(__file__).parent
+for model_cls, nch, img_size, pathname in benchmarks:
+    target = Path(f"/tmp/{pathname}")
+    print(f"Generating {pathname} to {target}")
+    if target.exists():
+        shutil.rmtree(target)
+    prefix = self_folder / "../model_params" / pathname
+    dataset_shape = 5000, nch, img_size, img_size
+    bin_tuneset = BinDataset(
+        prefix / "tune_input.bin", prefix / "tune_labels.bin", dataset_shape
+    )
+    bin_testset = BinDataset(
+        prefix / "test_input.bin", prefix / "test_labels.bin", dataset_shape
+    )
+    ModelExporter(model_cls(), bin_tuneset, bin_testset, target, True).export_all()
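For orientation (a hand-expanded example, not additional code from the commit), each benchmark tuple is (model class, input channels, input side length, model_params subdirectory); the ResNet-50 entry, for instance, resolves inside the loop to:

    # (dnn.ResNet50, 3, 224, "resnet50_imagenet")
    dataset_shape = 5000, 3, 224, 224
    prefix = self_folder / "../model_params" / "resnet50_imagenet"
    target = Path("/tmp/resnet50_imagenet")  # deleted first if it already exists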