From db54bdd812e0525590b347b444c1dd80669f6c18 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Mon, 1 Feb 2021 04:50:59 -0600
Subject: [PATCH] Fixed corner cases and added test for generating all 10
 benchmarks

---
 .../torch2hpvm/torch2hpvm/graph_builder.py    |  8 +++-
 .../torch2hpvm/torch2hpvm/graph_ir.py         | 14 +++++--
 .../dnn_benchmarks/pytorch/dnn/__init__.py    |  2 +-
 hpvm/test/dnn_benchmarks/pytorch/dnn/vgg16.py |  5 +++
 .../dnn_benchmarks/pytorch/test_frontend.py   | 42 +++++++++++++++----
 5 files changed, 57 insertions(+), 14 deletions(-)

diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py b/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py
index 640c845ff2..7739a64f53 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py
@@ -221,9 +221,13 @@ class DFG(object):
             "Identity": g.IdentityNode,
             "Flatten": g.FlattenNode,
         }
-        if node.op_type in one_to_one_nodes:
+        if node.op_type not in one_to_one_nodes:
+            return None
+        try:
             return one_to_one_nodes[node.op_type](node.name, **attrs)
-        return None
+        except (TypeError, KeyError, ValueError, RuntimeError):
+            node_class = one_to_one_nodes[node.op_type]
+            raise ValueError(f"Node ({node_class}) creation failed")
 
 
 def def_use(nodes: Iterable) -> Tuple[dict, dict]:
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py b/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py
index 3309072723..a088e6eae5 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py
@@ -134,18 +134,23 @@ class Conv2DNode(DFGNode):
             raise ValueError("Convolution with different padding is unsupported")
         if list(dilations) != [1, 1]:
             raise ValueError("Dilation > 1 is unsupported")
-        if group != 1:
-            raise ValueError("Group > 1 is unsupported")
+        self.group = group
         self.pads = pads[0]
         self.sh, self.sw = strides
 
     def codegen(self):
         return (
             "tensorConvolution",
-            [self.pads, self.pads, self.sh, self.sw],
+            [self.pads, self.pads, self.sh, self.sw, 1, self.group],
         )
 
     def hpvm_codegen(self):
+        if self.group != 1:
+            return (
+                "__hpvm__tensor_group_convolution",
+                # 1 is conv_mode -- should always be 1
+                [self.pads, self.pads, self.sh, self.sw, 1, self.group],
+            )
         return (
             "__hpvm__tensor_convolution",
             [self.pads, self.pads, self.sh, self.sw],
@@ -323,7 +328,8 @@ class BatchNormalizationNode(DFGNode):
         input_shapes: Tuple[ShapeT, ShapeT],
         output_shape: ShapeT,
         epsilon: float,
-        axis: int,
+        axis: int = None,
+        momentum: float = None,
     ):
         super().__init__(name, input_shapes, output_shape)
         self.epsilon = epsilon
diff --git a/hpvm/test/dnn_benchmarks/pytorch/dnn/__init__.py b/hpvm/test/dnn_benchmarks/pytorch/dnn/__init__.py
index 6d7deb9955..f4a16c6a2c 100644
--- a/hpvm/test/dnn_benchmarks/pytorch/dnn/__init__.py
+++ b/hpvm/test/dnn_benchmarks/pytorch/dnn/__init__.py
@@ -1,6 +1,6 @@
 from .alexnet import AlexNet, AlexNet2, AlexNetImageNet
 from .datasets import CIFAR, MNIST, ImageNet
 from .lenet import LeNet
-from .vgg16 import VGG16Cifar10, VGG16Cifar100
+from .vgg16 import VGG16Cifar10, VGG16Cifar100, VGG16ImageNet
 from .mobilenet import MobileNet
 from .resnet import ResNet18, ResNet50
diff --git a/hpvm/test/dnn_benchmarks/pytorch/dnn/vgg16.py b/hpvm/test/dnn_benchmarks/pytorch/dnn/vgg16.py
index 3c518fd511..43ac8f9818 100644
--- a/hpvm/test/dnn_benchmarks/pytorch/dnn/vgg16.py
+++ b/hpvm/test/dnn_benchmarks/pytorch/dnn/vgg16.py
@@ -39,3 +39,8 @@ class VGG16Cifar10(_VGG16):
 class VGG16Cifar100(_VGG16):
     def __init__(self):
         super().__init__([512, 512, 100])
+
+
+class VGG16ImageNet(_VGG16):
+    def __init__(self):
+        super().__init__([25088, 4096, 4096, 1000])
diff --git a/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py b/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py
index 214d1a4237..1ca46cd702 100644
--- a/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py
+++ b/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py
@@ -1,9 +1,37 @@
+import os
+import site
 from pathlib import Path
-from torch2hpvm import ModelExporter, BinDataset
-from dnn import AlexNet
+import shutil
 
-prefix = Path(__file__).parent / "../model_params/alexnet_cifar10"
-dataset_shape = 5000, 3, 32, 32
-bin_tuneset = BinDataset(prefix / "tune_input.bin", prefix / "tune_labels.bin", dataset_shape)
-bin_testset = BinDataset(prefix / "test_input.bin", prefix / "test_labels.bin", dataset_shape)
-ModelExporter(AlexNet(), bin_tuneset, bin_testset, "/tmp/alexnet", True).export_all()
+from torch2hpvm import BinDataset, ModelExporter
+
+site.addsitedir(os.path.dirname(__file__))
+import dnn
+
+benchmarks = [
+    (dnn.LeNet, 1, 28, "lenet_mnist"),
+    (dnn.AlexNet, 3, 32, "alexnet_cifar10"),
+    (dnn.AlexNet2, 3, 32, "alexnet2_cifar10"),
+    (dnn.AlexNetImageNet, 3, 224, "alexnet_imagenet"),
+    (dnn.MobileNet, 3, 32, "mobilenet_cifar10"),
+    (dnn.ResNet18, 3, 32, "resnet18_cifar10"),
+    (dnn.ResNet50, 3, 224, "resnet50_imagenet"),
+    (dnn.VGG16Cifar10, 3, 32, "vgg16_cifar10"),
+    (dnn.VGG16Cifar100, 3, 32, "vgg16_cifar100"),
+    (dnn.VGG16ImageNet, 3, 224, "vgg16_imagenet"),
+]
+self_folder = Path(__file__).parent
+for model_cls, nch, img_size, pathname in benchmarks:
+    target = Path(f"/tmp/{pathname}")
+    print(f"Generating {pathname} to {target}")
+    if target.exists():
+        shutil.rmtree(target)
+    prefix = self_folder / "../model_params" / pathname
+    dataset_shape = 5000, nch, img_size, img_size
+    bin_tuneset = BinDataset(
+        prefix / "tune_input.bin", prefix / "tune_labels.bin", dataset_shape
+    )
+    bin_testset = BinDataset(
+        prefix / "test_input.bin", prefix / "test_labels.bin", dataset_shape
+    )
+    ModelExporter(model_cls(), bin_tuneset, bin_testset, target, True).export_all()
-- 
GitLab