Skip to content
Snippets Groups Projects
Commit b58b8819 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Remove unused values (DCE)

parent 7cd2a180
No related branches found
No related tags found
No related merge requests found
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
import networkx as nx
import onnx
......@@ -69,7 +69,8 @@ class DFG(object):
self._onnx_defs, self._onnx_uses = self.def_use(graph.node)
self._var_count = 0
self.tensors = tensors
self._graph = self.build_dfg(graph)
self._graph = self._build_dfg(graph)
self._dce() # Remove unused values
@property
def traverse_order(self) -> List[g.DFGNode]:
......@@ -88,7 +89,7 @@ class DFG(object):
return inputs[0]
@staticmethod
def def_use(nodes: list) -> Tuple[dict, dict]:
def def_use(nodes: Iterable) -> Tuple[dict, dict]:
"""Computes def/use relation from a list of node.
This method is duck-typed and operates on any node defining .input and .output.
......@@ -101,6 +102,13 @@ class DFG(object):
defs[output] = n
return defs, uses
def _dce(self):
_, uses = self.def_use(self._graph.nodes)
used_values = set(uses.keys())
unused_values = set(self.tensors.keys()) - used_values
for k in unused_values:
self.tensors.pop(k)
def _allocate_insert_var(
self, node1: g.DFGNode, node2: g.DFGNode, input_pos: int = 0
) -> None:
......@@ -109,7 +117,7 @@ class DFG(object):
node2.input[input_pos] = varname
self._var_count += 1
def detect_flatten(self, graph: GraphT) -> Tuple[Dict[str, NodeT], List[g.DFGNode]]:
def _detect_flatten(self, graph: GraphT) -> Tuple[Dict[str, NodeT], List[g.DFGNode]]:
# Look for a shape-gather-unsqueeze-concat chain
nodes = graph.node
included_nodes = {} # Name to node
......@@ -160,14 +168,14 @@ class DFG(object):
generated_nodes.append(g.FlattenNode.from_onnx_idiom(nodes))
return included_nodes, generated_nodes
def build_dfg(self, graph: GraphT) -> nx.DiGraph:
def _build_dfg(self, graph: GraphT) -> nx.DiGraph:
error_nodes, generated_nodes = [], []
used_onnx_nodes, flatten_nodes = self.detect_flatten(graph)
used_onnx_nodes, flatten_nodes = self._detect_flatten(graph)
generated_nodes.extend(flatten_nodes)
for onnx_node in graph.node:
if onnx_node.name in used_onnx_nodes:
continue
our_node = self.emit_node(onnx_node)
our_node = self._emit_node(onnx_node)
if our_node is None:
error_nodes.append(onnx_node)
else:
......@@ -190,7 +198,7 @@ class DFG(object):
return ret_graph
# This should be the place where partial evaluation happens
def emit_node(self, onnx_node: NodeT) -> Optional[List[g.DFGNode]]:
def _emit_node(self, onnx_node: NodeT) -> Optional[List[g.DFGNode]]:
if onnx_node.op_type == "Conv":
weight_tensor = self.tensors[onnx_node.input[1]]
assert isinstance(weight_tensor, WeightTensor)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment