Skip to content
Snippets Groups Projects
Commit 50cd9c21 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Made graph_ir node arguments explicit

parent 8068fd1f
No related branches found
No related tags found
No related merge requests found
......@@ -175,14 +175,14 @@ class DFG(object):
assert isinstance(weight_tensor, g.WeightTensor)
if len(weight_tensor.shape) != 4:
return None # Only supports 2D conv
conv_node = g.Conv2DNode(node.name, attrs)
conv_node = g.Conv2DNode(node.name, **attrs)
if len(predec) == 2:
return conv_node
# Split into conv followed by an addition
bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
return MarkedSubGraph.idiomatic_1to2(conv_node, bias_node, predec)
elif node.op_type in ("MatMul", "Gemm"):
mul_node = g.MatMulNode(node.name, attrs)
mul_node = g.MatMulNode(node.name, **attrs)
if node.op_type == "Gemm":
mul_node.gemm_transpose(node, predec)
if len(predec) == 2:
......@@ -203,7 +203,7 @@ class DFG(object):
"Flatten": g.FlattenNode,
}
if node.op_type in one_to_one_nodes:
return one_to_one_nodes[node.op_type](node.name, attrs)
return one_to_one_nodes[node.op_type](node.name, **attrs)
return None
......
......@@ -15,8 +15,8 @@ class DFGNode(abc.ABC):
op_type = ""
def __init__(self, name: str, attrs: dict = {}):
self.name, self.attrs = name, attrs
def __init__(self, name: str, **kwargs):
self.name = name
def codegen(self) -> Tuple[str, list]:
return "", []
......@@ -37,7 +37,7 @@ class TensorNode(DFGNode, abc.ABC):
def __init__(self, proto: onnx.TensorProto, new_name: str):
if not proto.name.strip():
raise ValueError("Tensor's name is required.")
super().__init__(proto.name, {})
super().__init__(proto.name)
self.new_name = new_name
def __str__(self):
......@@ -102,25 +102,28 @@ class WeightTensor(TensorNode):
class Conv2DNode(DFGNode):
op_type = "Conv2D"
def __init__(self, name: str, attrs: dict):
super().__init__(name, attrs)
padding = self.attrs["pads"]
assert len(padding) == 4, "2D convolution must have 4 padding values"
if any(p != padding[0] for p in padding[1:]):
def __init__(self, name: str, pads, strides, dilations, group: int, kernel_shape):
super().__init__(name)
assert len(pads) == 4, "2D convolution must have 4 padding values"
if any(p != pads[0] for p in pads[1:]):
raise ValueError("Convolution with different padding is unsupported")
self.padding = padding[0]
self.strides = self.attrs["strides"]
if dilations != [1, 1]:
raise ValueError("Dilation > 1 is unsupported")
if group != 1:
raise ValueError("Group > 1 is unsupported")
self.pads = pads[0]
self.strides = strides
def codegen(self):
return (
"tensorConvolution",
[self.padding, self.padding, self.strides[0], self.strides[1]],
[self.pads, self.pads, self.strides[0], self.strides[1]],
)
def hpvm_codegen(self):
return (
"__visc__tensor_convolution",
[self.padding, self.padding, self.strides[0], self.strides[1]],
[self.pads, self.pads, self.strides[0], self.strides[1]],
)
......@@ -129,20 +132,24 @@ class _Pool2DNode(DFGNode, abc.ABC):
pool_type = "0"
def __init__(self, name: str, attrs: dict):
super().__init__(name, attrs)
self.strides = self.attrs["strides"]
self.pool_size = self.attrs["kernel_shape"]
self.padding = 0
def __init__(self, name: str, strides, kernel_shape, pads, ceil_mode: int):
super().__init__(name)
self.strides = strides
self.kernel_shape = kernel_shape
pt, pb, pl, pr = pads
if pt != pb or pl != pr:
raise ValueError("Unequal padding on either side of same axis is unsupported")
self.pads = pt, pl
if ceil_mode != 0:
raise ValueError("Only ceil_mode == 0 is supported")
def codegen(self):
return (
"tensorPooling",
[
self.pool_type,
*self.pool_size,
self.padding,
self.padding,
*self.kernel_shape,
*self.pads,
*self.strides,
],
)
......@@ -150,7 +157,7 @@ class _Pool2DNode(DFGNode, abc.ABC):
def hpvm_codegen(self):
return (
"__visc__tensor_pool_max",
[*self.pool_size, self.padding, self.padding, *self.strides],
[*self.kernel_shape, *self.pads, *self.strides],
)
......@@ -249,9 +256,9 @@ class TanhNode(DFGNode):
class BatchNormalizationNode(DFGNode):
op_type = "BN"
def __init__(self, name: str, attrs: dict):
super().__init__(name, attrs)
self.epsilon = self.attrs["epsilon"]
def __init__(self, name: str, epsilon: float, axis: int):
super().__init__(name)
self.epsilon = epsilon
def codegen(self):
return "tensorBatchNorm", [self.epsilon]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment