Skip to content
Snippets Groups Projects
Commit f7ade276 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Added support for unused pad node

parent 30500ca3
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@ import graph_ir as g
GraphT = onnx.GraphProto
NodeT = onnx.NodeProto
NodeT.__hash__ = lambda self: id(self)
NodeT.__repr__ = NodeT.__str__ = lambda self: self.name
class MarkedSubGraph:
......@@ -74,7 +75,8 @@ class DFG(object):
if len(onnx_graph.output) > 1:
raise ValueError("Graph must have single output")
def _build_onnx_dfg(self, graph: GraphT) -> nx.DiGraph:
@staticmethod
def _build_onnx_dfg(graph: GraphT) -> nx.DiGraph:
"""Creates a DiGraph (by use-def relation) of onnx nodes from onnx GraphProto.
DiGraph is easier to use as a graph compared to GraphProto where use-def is implicit."""
......@@ -92,6 +94,7 @@ class DFG(object):
def _build_dfg(self, onnx_graph: nx.DiGraph) -> nx.DiGraph:
onnx_graph = detect_flatten(onnx_graph)
onnx_graph = remove_no_padding(onnx_graph)
# For each onnx node, generate our nodes
node_to_nodes, error_nodes = {}, []
for onnx_node in nx.topological_sort(onnx_graph):
......@@ -183,6 +186,26 @@ def def_use(nodes: Iterable) -> Tuple[dict, dict]:
return defs, uses
def remove_no_padding(graph: nx.DiGraph) -> nx.DiGraph:
for node in list(graph.nodes):
if node.op_type != "Pad":
continue
input_args = sorted_inputs(graph, node)
# Find the second input argument to Pad (will be a Constant node)
# and take that away as well.
nct = input_args[1]
padding = node_attr_to_dict(nct)["value"]
if any(p != 0 for p in padding):
continue
# Connect input of Pad to where output of Pad goes
succ = graph.out_edges(node, "index")
for _, to, index in succ:
graph.add_edge(input_args[0], to, index=index)
# Remove nodes
graph.remove_nodes_from([node, nct])
return graph
def detect_flatten(graph: nx.DiGraph) -> nx.DiGraph:
"""Look for a shape-gather-unsqueeze-concat-reshape chain and replace that with flatten."""
......@@ -280,3 +303,11 @@ def extract_tensors_from_graph(onnx_graph: GraphT) -> Dict[str, g.TensorNode]:
def sorted_inputs(graph: nx.DiGraph, node):
sorted_edges = sorted(graph.in_edges(node, "index"), key=lambda p: p[2])
return [e[0] for e in sorted_edges]
def draw_graph(graph: nx.DiGraph, output_to):
from networkx.drawing.nx_agraph import to_agraph
agraph = to_agraph(graph)
agraph.layout("dot")
agraph.draw(output_to)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment