Commit 5f46b4b1 authored by Yifan Zhao

Use attribute tool, don't parse attributes by hand

parent baa23f74
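
For context: the onnx_attr module referenced in this diff is not shown here. A minimal sketch of what node_attr_to_dict presumably does, built on onnx.helper.get_attribute_value (the module name, function name, and exact behavior are assumptions, not taken from this commit):

    # Hypothetical sketch of onnx_attr.node_attr_to_dict; the real module is not
    # part of this diff, so names and behavior here are assumptions.
    from typing import Any, Dict

    import onnx
    from onnx import helper


    def node_attr_to_dict(onnx_node: onnx.NodeProto) -> Dict[str, Any]:
        # get_attribute_value unwraps each AttributeProto into a plain Python
        # value (float, int, list of ints, bytes, ...), keyed by attribute name.
        return {
            attr.name: helper.get_attribute_value(attr)
            for attr in onnx_node.attribute
        }
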
@@ -110,10 +110,11 @@ class DFG(object):
         for k in unused_values:
             self.tensors.pop(k)
 
-    def _allocate_insert_var(
-        self, node1: g.DFGNode, node2: g.DFGNode, input_pos: int = 0
+    def _split_node_args(
+        self, node1: g.DFGNode, node2: g.DFGNode, input_pos: int = 0, pop_pos: int = -1
     ) -> None:
         varname = f"conv_{self._var_count}"
+        node1.input.pop(pop_pos)
         node1.output = [varname]
         node2.input[input_pos] = varname
         self._var_count += 1
@@ -211,7 +212,7 @@ class DFG(object):
             # Add an intermediate var between conv and add
             conv_node = g.Conv2DNode(onnx_node)
             bias_node = g.BiasAddNode(onnx_node)
-            self._allocate_insert_var(conv_node, bias_node)
+            self._split_node_args(conv_node, bias_node)
             return [conv_node, bias_node]
         elif onnx_node.op_type in ("MatMul", "Gemm"):
             if onnx_node.op_type == "Gemm":
@@ -230,8 +231,7 @@ class DFG(object):
             # Add an intermediate var between matmul and add
             mul_node = g.MatMulNode(onnx_node)
             bias_node = g.BiasAddNode(onnx_node)
-            self._allocate_insert_var(mul_node, bias_node)
-            mul_node.input.pop()
+            self._split_node_args(mul_node, bias_node)
             return [mul_node, bias_node]
         one_to_one_nodes = {
             "MaxPool": g.MaxPool2DNode,
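
The renamed _split_node_args now pops the bias operand itself (pop_pos defaults to -1, the last input), so the Gemm call site no longer needs the manual mul_node.input.pop(). Roughly, the effect on a Gemm/Conv node pair is as in this illustrative sketch (stand-in objects only; the real DFGNode classes live in the frontend's graph module):

    # Illustrative only: what _split_node_args does to a matmul/bias pair,
    # assuming _var_count == 0 so the intermediate variable is "conv_0".
    class FakeNode:
        def __init__(self, input, output):
            self.input = list(input)
            self.output = list(output)

    mul_node = FakeNode(["A", "B", "C"], ["Y"])   # C is the bias operand
    bias_node = FakeNode(["Y", "C"], ["Y"])       # BiasAddNode: matmul output + bias

    mul_node.input.pop(-1)                        # drop the bias from the first node
    mul_node.output = ["conv_0"]                  # fresh intermediate variable
    bias_node.input[0] = "conv_0"                 # bias add consumes the intermediate

    assert mul_node.input == ["A", "B"]
    assert bias_node.input == ["conv_0", "C"]
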
@@ -4,8 +4,11 @@
 from typing import List
 
 import onnx
 
+from onnx_attr import node_attr_to_dict
 
 class DFGNode:
     def __init__(self, onnx_node: onnx.NodeProto):
@@ -41,9 +44,7 @@ class BiasAddNode(DFGNode):
     def __init__(self, onnx_conv_node: onnx.NodeProto):
         super().__init__(onnx_conv_node)
         self.op_type = "BiasAdd"
-        self.input = list()
-        self.input.append(self.output[0])
-        self.input.append(onnx_conv_node.input[2])
+        self.input = [onnx_conv_node.output[0], onnx_conv_node.input[2]]
 
     def codegen(self):
         return "tensorAdd", []
@@ -71,20 +72,13 @@ class SoftMaxNode(DFGNode):
 class Conv2DNode(DFGNode):
     def __init__(self, onnx_node: onnx.NodeProto):
         super().__init__(onnx_node)
-        if len(self.input) == 3:
-            tmp_input = list()
-            for i in self.input:
-                tmp_input.append(i)
-            self.input = tmp_input
-            self.input.pop()  # remove the last index for bias add
-        self.padding = 0
-        self.strides = list()
-        for attr in onnx_node.attribute:
-            if attr.name == "pads":
-                self.padding = attr.ints[0]
-            elif attr.name == "strides":
-                for stride in attr.ints:
-                    self.strides.append(stride)
+        attrs = node_attr_to_dict(onnx_node)
+        padding = attrs["pads"]
+        assert len(padding) == 4, "2D convolution must have 4 padding values"
+        if any(p != padding[0] for p in padding[1:]):
+            raise ValueError("Convolution with different padding is unsupported")
+        self.padding = padding[0]
+        self.strides = attrs["strides"]
 
     def codegen(self):
         return (
@@ -102,17 +96,11 @@ class Conv2DNode(DFGNode):
 class MaxPool2DNode(DFGNode):
     def __init__(self, onnx_node: onnx.NodeProto):
         super().__init__(onnx_node)
-        self.strides = list()
-        self.pool_size = list()
+        attr = node_attr_to_dict(onnx_node)
+        self.strides = attr["strides"]
+        self.pool_size = attr["kernel_shape"]
         self.padding = 0
         self.pool_type = "0"
-        for attr in onnx_node.attribute:
-            if attr.name == "kernel_shape":
-                for pool in attr.ints:
-                    self.pool_size.append(pool)
-            elif attr.name == "strides":
-                for stride in attr.ints:
-                    self.strides.append(stride)
 
     def codegen(self):
         return (
@@ -136,17 +124,11 @@ class MaxPool2DNode(DFGNode):
 class AveragePool2DNode(DFGNode):
     def __init__(self, onnx_node: onnx.NodeProto):
         super().__init__(onnx_node)
-        self.strides = list()
-        self.pool_size = list()
+        attr = node_attr_to_dict(onnx_node)
+        self.strides = attr["strides"]
+        self.pool_size = attr["kernel_shape"]
         self.padding = 0
         self.pool_type = "1"
-        for attr in onnx_node.attribute:
-            if attr.name == "kernel_shape":
-                for pool in attr.ints:
-                    self.pool_size.append(pool)
-            elif attr.name == "strides":
-                for stride in attr.ints:
-                    self.strides.append(stride)
 
     def codegen(self):
         return (
@@ -186,10 +168,8 @@ class TanhNode(DFGNode):
 class BatchNormalizationNode(DFGNode):
     def __init__(self, onnx_node: onnx.NodeProto):
         super().__init__(onnx_node)
-        self.epsilon = ""
-        for attr in onnx_node.attribute:
-            if attr.name == "epsilon":
-                self.epsilon = str(attr.f)
+        attr = node_attr_to_dict(onnx_node)
+        self.epsilon = attr["epsilon"]
 
     def codegen(self):
         return "tensorBatchNorm", [self.epsilon]
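
A quick way to see what the attribute dictionary looks like for a pooling node, using the node_attr_to_dict sketch above (values follow from the attributes passed to make_node):

    import onnx.helper as helper

    pool = helper.make_node(
        "MaxPool", inputs=["x"], outputs=["y"],
        kernel_shape=[2, 2], strides=[2, 2],
    )
    attrs = node_attr_to_dict(pool)
    print(attrs["kernel_shape"], attrs["strides"])  # [2, 2] [2, 2]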