Skip to content
Snippets Groups Projects
Commit c3ade693 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Separate matmul and add nodes

parent 02bcd130
No related branches found
No related tags found
No related merge requests found
......@@ -107,10 +107,11 @@ class DFG(object):
defs[output] = n
return defs, uses
def _allocate_var(self):
def _allocate_insert_var(self, node1, node2, input_pos: int = 0):
varname = f"conv_{self._var_count}"
node1.output = [varname]
node2.input[input_pos] = varname
self._var_count += 1
return varname
def detect_flatten(self, graph):
# Look for a shape-gather-unsqueeze-concat chain
......@@ -203,17 +204,24 @@ class DFG(object):
return [g.Conv2DNode(onnx_node)]
else:
# Add an intermediate var between conv and add
interm_var = self._allocate_var()
conv_node = g.Conv2DNode(onnx_node)
conv_node.output = [interm_var]
bias_node = g.BiasAddNode(onnx_node)
bias_node.input[0] = interm_var
self._allocate_insert_var(conv_node, bias_node)
return [conv_node, bias_node]
elif onnx_node.op_type in ("MatMul", "Gemm"):
weight_tensor = self.tensors[onnx_node.input[1]]
assert isinstance(weight_tensor, WeightTensor)
if len(onnx_node.input) == 2:
return [g.MatMulNode(onnx_node)]
else:
# Add an intermediate var between matmul and add
mul_node = g.MatMulNode(onnx_node)
bias_node = g.BiasAddNode(onnx_node)
self._allocate_insert_var(mul_node, bias_node)
return [mul_node, bias_node]
one_to_one_nodes = {
"MaxPool": g.MaxPool2DNode,
"AveragePool": g.AveragePool2DNode,
"MatMul": g.MatMulNode,
"Gemm": g.MatMulNode,
"Add": g.AddNode,
"Softmax": g.SoftMaxNode,
"Relu": g.ReluNode,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment