Skip to content
Snippets Groups Projects
Commit d5f71a5b authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Some comments

parent f7ade276
No related branches found
No related tags found
No related merge requests found
from collections import defaultdict, namedtuple
from onnx_attr import node_attr_to_dict
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
......@@ -8,6 +7,7 @@ import networkx as nx
import onnx
import graph_ir as g
from onnx_attr import node_attr_to_dict
GraphT = onnx.GraphProto
NodeT = onnx.NodeProto
......@@ -16,6 +16,15 @@ NodeT.__repr__ = NodeT.__str__ = lambda self: self.name
class MarkedSubGraph:
"""A subgraph with information on how it should replace a node in a super graph.
subgraph: a nx.DiGraph subgraph
entry_edges: a list of edges from nodes "outside" to nodes in self.subgraph
exit: the exit node of the subgraph.
When this subgraph replaces a node `n`, self.exit will be connected to
whateven `n` is connected to.
"""
def __init__(self, subgraph: nx.DiGraph, entry_edges, exit) -> None:
assert all(to in subgraph for _, to, _ in entry_edges)
assert exit in subgraph
......@@ -24,6 +33,9 @@ class MarkedSubGraph:
@classmethod
def idiomatic_1to2(cls, node1, node2, predecessors):
"""Create an idiomatic replacement as follow:
node(arg1, arg2, arg3) -> node2(node1(arg1, arg2), arg3)"""
p0, p1, p2 = predecessors
graph = nx.DiGraph()
graph.add_edge(node1, node2, index=0)
......@@ -34,24 +46,36 @@ EmitNodeT = Union[MarkedSubGraph, g.DFGNode]
class DFG(object):
"""ONNX model translated into DFG with `DFGNode`s.
This class has a DFG, input/output information, and a clear traverse order
(think dominant tree), and is easier for CodeGen classes to work with."""
def __init__(self, graph: GraphT):
self._check_model(graph)
self._var_count = 0
# Build explicit DFG with ONNX nodes
onnx_graph = self._build_onnx_dfg(graph)
# Convert ONNX dfg into DFGNode DFG
self.graph = self._build_dfg(onnx_graph)
# Find out input nodes and output node (unique)
# removing dead nodes along the way if any
self.inputs, self.output = self._dce_get_io_info()
################ Interfaces:
@property
def traverse_order(self) -> List[g.DFGNode]:
"""Get topological order of computational graph by use-def relation."""
return list(nx.topological_sort(self.graph))
def node_args(self, node):
def node_args(self, node: g.DFGNode):
"""Get input arguments of node."""
sorted_edges = sorted(self.graph.in_edges(node, "index"), key=lambda p: p[2])
return [e[0] for e in sorted_edges]
def dump_weights(self, output_dir: PathLike) -> None:
"""Dump `WeightTensor`s into output_dir."""
output_dir = Path(output_dir)
for node in self.graph.nodes:
if not isinstance(node, g.WeightTensor):
......@@ -62,6 +86,8 @@ class DFG(object):
@staticmethod
def _check_model(onnx_graph: GraphT):
"""Check model validaty and single output (which is our limitation)"""
import warnings
from onnx import checker, onnx_cpp2py_export
......@@ -93,7 +119,15 @@ class DFG(object):
return ret_graph
def _build_dfg(self, onnx_graph: nx.DiGraph) -> nx.DiGraph:
"""Translate _build_onnx_dfg output into DFGNode DFG.
First run some passes to process subgraphs that needs to be
processed together, then each unprocessed node is generated into
1 or more nodes."""
# Remove subgraphs that can be a single Flatten instead
onnx_graph = detect_flatten(onnx_graph)
# Remove subgraphs that look like padding but does nothing
onnx_graph = remove_no_padding(onnx_graph)
# For each onnx node, generate our nodes
node_to_nodes, error_nodes = {}, []
......@@ -109,6 +143,7 @@ class DFG(object):
raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}")
else:
raise ValueError(f"Unsupported operators: {error_repr}")
# Apply node_to_nodes replacement on onnx_graph to create a new DFG
return build_graph_with_mapping(onnx_graph, node_to_nodes)
def _dce_get_io_info(self):
......@@ -187,6 +222,7 @@ def def_use(nodes: Iterable) -> Tuple[dict, dict]:
def remove_no_padding(graph: nx.DiGraph) -> nx.DiGraph:
"""Remove subgraphs that look like padding but does nothing."""
for node in list(graph.nodes):
if node.op_type != "Pad":
continue
......
################################################
# Top Level DFGNode interface
################################################
import abc
from os import PathLike
from typing import List
from typing import List, Tuple
import onnx
......@@ -13,27 +8,32 @@ from onnx_attr import node_attr_to_dict
class DFGNode(abc.ABC):
"""Abstract node that represents 1 instruction in HPVM.
op_type should be overriden in subclasses for readability.
"""
op_type = ""
def __init__(self, name: str, attrs: dict = {}):
self.name, self.attrs = name, attrs
def codegen(self):
def codegen(self) -> Tuple[str, list]:
return "", []
def hpvm_codegen(self):
def hpvm_codegen(self) -> Tuple[str, list]:
return "", []
def __repr__(self):
def __repr__(self) -> str:
return f"{self.op_type}({self.name})"
################################################
# Actual Implementation of Operators
################################################
class TensorNode(DFGNode, abc.ABC):
"""An abstract node for a value that exists without an instruction.
This is akin to Value class in LLVM, but in a different place on the
inheritance tree."""
class TensorNode(DFGNode):
def __init__(self, proto: onnx.TensorProto, new_name: str):
if not proto.name.strip():
raise ValueError("Tensor's name is required.")
......@@ -47,6 +47,12 @@ class TensorNode(DFGNode):
class InputTensor(TensorNode):
"""Input to the computation graph.
This is basically only used for its information about the ONNX input,
itself doesn't emit instruction or any interesting thing.
"""
op_type = "InputTensor"
def __init__(self, input_proto: onnx.TensorProto, new_name: str):
......@@ -59,6 +65,12 @@ class InputTensor(TensorNode):
class WeightTensor(TensorNode):
"""An initialized parameter in ONNX graph.
This is any parameter that has a initializer value in the ONNX model
(as opposed to InputTensor, which doesn't have any value).
"""
op_type = "WeightTensor"
def __init__(self, weight_proto: onnx.TensorProto, new_name: str):
......@@ -113,6 +125,8 @@ class Conv2DNode(DFGNode):
class _Pool2DNode(DFGNode, abc.ABC):
"""Common super class of Average pooling and Max pooling."""
pool_type = "0"
def __init__(self, name: str, attrs: dict):
......@@ -170,7 +184,13 @@ class MatMulNode(DFGNode):
return "__visc__tensor_mul", []
@staticmethod
def gemm_transpose(onnx_node, predec):
def gemm_transpose(onnx_gemm_node, predec):
"""Find and transpose weights of the onnx gemm node.
This way we transpose the constant weight instead of exporting
a transpose node (which doesn't yet exist in HPVM).
"""
def _transpose(weight):
if not isinstance(weight, WeightTensor):
raise ValueError(
......@@ -179,7 +199,7 @@ class MatMulNode(DFGNode):
weight.transpose_()
# Some tensors may need transposing
attrs = node_attr_to_dict(onnx_node)
attrs = node_attr_to_dict(onnx_gemm_node)
if attrs.get("transA", False):
_transpose(predec[0])
if attrs.get("transB", False):
......@@ -257,7 +277,7 @@ class ActivationNode(DFGNode):
class LogicalOpNode(DFGNode):
"""
ELement wise operators that is not for activation function.
Element wise operators that is not for activation function.
In other words, they are logical comparison operators
e.g. And, Equal, Greater, GreaterOrEqual, Less, LessOrEqual,
Or, Xor
......
......@@ -90,7 +90,7 @@ hpvmc: HPVM C Interface. Default value is hpvmc.""",
args = parser.parse_args()
args.hpvmc = args.compile_mode == "hpvmc"
delattr(args, 'compile_mode')
delattr(args, "compile_mode")
return args
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment