From d5f71a5bcaffd64a6356f7d93904f275e742f2a3 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Sun, 6 Dec 2020 03:32:13 -0600
Subject: [PATCH] Some comments

---
 hpvm/projects/onnx/frontend/graph_builder.py | 42 ++++++++++++++--
 hpvm/projects/onnx/frontend/graph_ir.py      | 52 ++++++++++++++------
 hpvm/projects/onnx/frontend/main.py          |  2 +-
 3 files changed, 76 insertions(+), 20 deletions(-)

diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py
index e6467bc158..3996b947d2 100644
--- a/hpvm/projects/onnx/frontend/graph_builder.py
+++ b/hpvm/projects/onnx/frontend/graph_builder.py
@@ -1,5 +1,4 @@
-from collections import defaultdict, namedtuple
-from onnx_attr import node_attr_to_dict
+from collections import defaultdict
 from os import PathLike
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, Tuple, Union
@@ -8,6 +7,7 @@ import networkx as nx
 import onnx
 
 import graph_ir as g
+from onnx_attr import node_attr_to_dict
 
 GraphT = onnx.GraphProto
 NodeT = onnx.NodeProto
@@ -16,6 +16,15 @@ NodeT.__repr__ = NodeT.__str__ = lambda self: self.name
 
 
 class MarkedSubGraph:
+    """A subgraph with information on how it should replace a node in a super graph.
+
+    subgraph: a nx.DiGraph subgraph
+    entry_edges: a list of edges from nodes "outside" to nodes in self.subgraph
+    exit: the exit node of the subgraph.
+        When this subgraph replaces a node `n`, self.exit will be connected to
+        whatever `n` is connected to.
+    """
+
     def __init__(self, subgraph: nx.DiGraph, entry_edges, exit) -> None:
         assert all(to in subgraph for _, to, _ in entry_edges)
         assert exit in subgraph
@@ -24,6 +33,9 @@ class MarkedSubGraph:
 
     @classmethod
     def idiomatic_1to2(cls, node1, node2, predecessors):
+        """Create an idiomatic replacement as follow:
+
+        node(arg1, arg2, arg3) -> node2(node1(arg1, arg2), arg3)"""
         p0, p1, p2 = predecessors
         graph = nx.DiGraph()
         graph.add_edge(node1, node2, index=0)
@@ -34,24 +46,36 @@ EmitNodeT = Union[MarkedSubGraph, g.DFGNode]
 
 
 class DFG(object):
+    """ONNX model translated into DFG with `DFGNode`s.
+
+    This class has a DFG, input/output information, and a clear traverse order
+    (think dominator tree), and is easier for CodeGen classes to work with."""
+
     def __init__(self, graph: GraphT):
         self._check_model(graph)
         self._var_count = 0
+        # Build explicit DFG with ONNX nodes
         onnx_graph = self._build_onnx_dfg(graph)
+        # Convert ONNX dfg into DFGNode DFG
         self.graph = self._build_dfg(onnx_graph)
+        # Find out input nodes and output node (unique)
+        # removing dead nodes along the way if any
         self.inputs, self.output = self._dce_get_io_info()
 
     ################ Interfaces:
 
     @property
     def traverse_order(self) -> List[g.DFGNode]:
+        """Get topological order of computational graph by use-def relation."""
         return list(nx.topological_sort(self.graph))
 
-    def node_args(self, node):
+    def node_args(self, node: g.DFGNode):
+        """Get input arguments of node."""
         sorted_edges = sorted(self.graph.in_edges(node, "index"), key=lambda p: p[2])
         return [e[0] for e in sorted_edges]
 
     def dump_weights(self, output_dir: PathLike) -> None:
+        """Dump `WeightTensor`s into output_dir."""
         output_dir = Path(output_dir)
         for node in self.graph.nodes:
             if not isinstance(node, g.WeightTensor):
@@ -62,6 +86,8 @@ class DFG(object):
 
     @staticmethod
     def _check_model(onnx_graph: GraphT):
+        """Check model validaty and single output (which is our limitation)"""
+
         import warnings
         from onnx import checker, onnx_cpp2py_export
 
@@ -93,7 +119,15 @@ class DFG(object):
         return ret_graph
 
     def _build_dfg(self, onnx_graph: nx.DiGraph) -> nx.DiGraph:
+        """Translate _build_onnx_dfg output into DFGNode DFG.
+        
+        First run some passes to process subgraphs that need to be
+        processed together, then each unprocessed node is generated into
+        1 or more nodes."""
+
+        # Remove subgraphs that can be a single Flatten instead
         onnx_graph = detect_flatten(onnx_graph)
+        # Remove subgraphs that look like padding but does nothing
         onnx_graph = remove_no_padding(onnx_graph)
         # For each onnx node, generate our nodes
         node_to_nodes, error_nodes = {}, []
@@ -109,6 +143,7 @@ class DFG(object):
                 raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}")
             else:
                 raise ValueError(f"Unsupported operators: {error_repr}")
+        # Apply node_to_nodes replacement on onnx_graph to create a new DFG
         return build_graph_with_mapping(onnx_graph, node_to_nodes)
 
     def _dce_get_io_info(self):
@@ -187,6 +222,7 @@ def def_use(nodes: Iterable) -> Tuple[dict, dict]:
 
 
 def remove_no_padding(graph: nx.DiGraph) -> nx.DiGraph:
+    """Remove subgraphs that look like padding but does nothing."""
     for node in list(graph.nodes):
         if node.op_type != "Pad":
             continue
diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py
index 9ffff9944c..fd42617f60 100644
--- a/hpvm/projects/onnx/frontend/graph_ir.py
+++ b/hpvm/projects/onnx/frontend/graph_ir.py
@@ -1,11 +1,6 @@
-################################################
-# Top Level DFGNode interface
-################################################
-
-
 import abc
 from os import PathLike
-from typing import List
+from typing import List, Tuple
 
 import onnx
 
@@ -13,27 +8,32 @@ from onnx_attr import node_attr_to_dict
 
 
 class DFGNode(abc.ABC):
+    """Abstract node that represents 1 instruction in HPVM.
+    
+    op_type should be overridden in subclasses for readability.
+    """
+
     op_type = ""
 
     def __init__(self, name: str, attrs: dict = {}):
         self.name, self.attrs = name, attrs
 
-    def codegen(self):
+    def codegen(self) -> Tuple[str, list]:
         return "", []
 
-    def hpvm_codegen(self):
+    def hpvm_codegen(self) -> Tuple[str, list]:
         return "", []
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.op_type}({self.name})"
 
 
-################################################
-# Actual Implementation of Operators
-################################################
+class TensorNode(DFGNode, abc.ABC):
+    """An abstract node for a value that exists without an instruction.
 
+    This is akin to Value class in LLVM, but in a different place on the
+    inheritance tree."""
 
-class TensorNode(DFGNode):
     def __init__(self, proto: onnx.TensorProto, new_name: str):
         if not proto.name.strip():
             raise ValueError("Tensor's name is required.")
@@ -47,6 +47,12 @@ class TensorNode(DFGNode):
 
 
 class InputTensor(TensorNode):
+    """Input to the computation graph.
+    
+    This is basically only used for its information about the ONNX input,
+    it doesn't itself emit instructions or anything else of interest.
+    """
+
     op_type = "InputTensor"
 
     def __init__(self, input_proto: onnx.TensorProto, new_name: str):
@@ -59,6 +65,12 @@ class InputTensor(TensorNode):
 
 
 class WeightTensor(TensorNode):
+    """An initialized parameter in ONNX graph.
+    
+    This is any parameter that has an initializer value in the ONNX model
+    (as opposed to InputTensor, which doesn't have any value).
+    """
+
     op_type = "WeightTensor"
 
     def __init__(self, weight_proto: onnx.TensorProto, new_name: str):
@@ -113,6 +125,8 @@ class Conv2DNode(DFGNode):
 
 
 class _Pool2DNode(DFGNode, abc.ABC):
+    """Common super class of Average pooling and Max pooling."""
+
     pool_type = "0"
 
     def __init__(self, name: str, attrs: dict):
@@ -170,7 +184,13 @@ class MatMulNode(DFGNode):
         return "__visc__tensor_mul", []
 
     @staticmethod
-    def gemm_transpose(onnx_node, predec):
+    def gemm_transpose(onnx_gemm_node, predec):
+        """Find and transpose weights of the onnx gemm node.
+        
+        This way we transpose the constant weight instead of exporting
+        a transpose node (which doesn't yet exist in HPVM).
+        """
+
         def _transpose(weight):
             if not isinstance(weight, WeightTensor):
                 raise ValueError(
@@ -179,7 +199,7 @@ class MatMulNode(DFGNode):
             weight.transpose_()
 
         # Some tensors may need transposing
-        attrs = node_attr_to_dict(onnx_node)
+        attrs = node_attr_to_dict(onnx_gemm_node)
         if attrs.get("transA", False):
             _transpose(predec[0])
         if attrs.get("transB", False):
@@ -257,7 +277,7 @@ class ActivationNode(DFGNode):
 
 class LogicalOpNode(DFGNode):
     """
-    ELement wise operators that is not for activation function.
+    Element wise operators that is not for activation function.
     In other words, they are logical comparison operators
     e.g. And, Equal, Greater, GreaterOrEqual, Less, LessOrEqual,
     Or, Xor
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index 25e76ae8fa..d3037224ac 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -90,7 +90,7 @@ hpvmc: HPVM C Interface. Default value is hpvmc.""",
 
     args = parser.parse_args()
     args.hpvmc = args.compile_mode == "hpvmc"
-    delattr(args, 'compile_mode')
+    delattr(args, "compile_mode")
     return args
 
 
-- 
GitLab