diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py
index db7cf208e4516ca2d3f221cb6effa31c5641cde6..f7260f2b2c1edfc1f88271e99386a1db6d20319d 100644
--- a/hpvm/projects/onnx/frontend/graph_builder.py
+++ b/hpvm/projects/onnx/frontend/graph_builder.py
@@ -1,15 +1,23 @@
+from collections import defaultdict
 from os import PathLike
 from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 import networkx as nx
+import onnx
+
 import graph_ir as g
-from tensor import InputTensor, WeightTensor
-from collections import defaultdict
+from tensor import InputTensor, Tensor, WeightTensor
 
+ModelT = onnx.ModelProto
+GraphT = onnx.GraphProto
+NodeT = onnx.NodeProto
 
-class GraphBuilder(object):
-    def __init__(self, model, shape):
+
+class GraphBuilder:
+    def __init__(self, model: ModelT, shape: List[int] = None):
         self._check_model(model)
+        # TODO: what type is self.shape? pick one.
         self.shape = shape if shape else self._infer_shape(model.graph)
         self.tensors = self._extract_tensors_from_graph(model.graph)
         self.dfg = DFG(model.graph, self.tensors)
@@ -19,8 +27,9 @@ class GraphBuilder(object):
     ################################################
 
     @staticmethod
-    def _check_model(onnx_model):
+    def _check_model(onnx_model: ModelT):
         import warnings
+
         from onnx import checker, onnx_cpp2py_export
 
         if hasattr(checker, "check_model"):
@@ -31,7 +40,7 @@ class GraphBuilder(object):
                 warnings.warn(str(e))
 
     @staticmethod
-    def _infer_shape(onnx_graph):
+    def _infer_shape(onnx_graph: GraphT) -> Dict[str, List[int]]:
         shape = {}
         for input in onnx_graph.input:
             # get type of input tensor
@@ -42,7 +51,7 @@ class GraphBuilder(object):
         return shape
 
     @staticmethod
-    def _extract_tensors_from_graph(onnx_graph):
+    def _extract_tensors_from_graph(onnx_graph: GraphT) -> Dict[str, Tensor]:
         tensors = {}
         # parse weight
         weight_cnt = 0
@@ -69,12 +78,7 @@ class GraphBuilder(object):
                     tensors[i] = InputTensor(i)
         return tensors
 
-    ################################################
-    # Top level Graph Building functions
-    # return the compilation-ready graph
-    ################################################
-
-    def dump_weights(self, output_dir: PathLike):
+    def dump_weights(self, output_dir: PathLike) -> None:
         output_dir = Path(output_dir)
         for tensor in self.tensors.values():
             if not isinstance(tensor, WeightTensor):
@@ -83,22 +87,26 @@ class GraphBuilder(object):
 
 
 class DFG(object):
-    def __init__(self, graph, tensors):
+    def __init__(self, graph: GraphT, tensors: Dict[str, Tensor]):
         if len(graph.output) > 1:
             raise ValueError("Only single-output graph is supported")
-        self.inputs = graph.input
-        self.output = graph.output[0]
+        self.inputs: List[str] = graph.input
+        self.output: str = graph.output[0]
         self._onnx_defs, self._onnx_uses = self.def_use(graph.node)
         self._var_count = 0
         self.tensors = tensors
         self.graph = self.build_dfg(graph)
 
     @property
-    def traverse_order(self):
+    def traverse_order(self) -> List[g.DFGNode]:
         return list(nx.topological_sort(self.graph))
 
     @staticmethod
-    def def_use(nodes):
+    def def_use(nodes: list) -> Tuple[dict, dict]:
+        """Computes def/use relation from a list of node.
+
+        This method is duck-typed and operates on any node defining .input and .output.
+        """
         defs, uses = {}, defaultdict(list)
         for n in nodes:
             for input_ in n.input:
@@ -107,19 +115,21 @@ class DFG(object):
                 defs[output] = n
         return defs, uses
 
-    def _allocate_insert_var(self, node1, node2, input_pos: int = 0):
+    def _allocate_insert_var(
+        self, node1: g.DFGNode, node2: g.DFGNode, input_pos: int = 0
+    ) -> None:
         varname = f"conv_{self._var_count}"
         node1.output = [varname]
         node2.input[input_pos] = varname
         self._var_count += 1
 
-    def detect_flatten(self, graph):
+    def detect_flatten(self, graph: GraphT) -> Tuple[Dict[str, NodeT], List[g.DFGNode]]:
         # Look for a shape-gather-unsqueeze-concat chain
         nodes = graph.node
         included_nodes = {}  # Name to node
         generated_nodes = []
 
-        def get_next_in_chain(type_: str, node) -> str:
+        def get_next_in_chain(type_: str, node: Optional[NodeT]) -> Optional[NodeT]:
             """
             Get a unique user node of the unique output of Node `node`,
             and return it if it has Type `type_`.
@@ -164,7 +174,7 @@ class DFG(object):
                 generated_nodes.append(g.FlattenNode.from_onnx_idiom(nodes))
         return included_nodes, generated_nodes
 
-    def build_dfg(self, graph) -> nx.DiGraph:
+    def build_dfg(self, graph: GraphT) -> nx.DiGraph:
         error_nodes, generated_nodes = [], []
         used_onnx_nodes, flatten_nodes = self.detect_flatten(graph)
         generated_nodes.extend(flatten_nodes)
@@ -182,19 +192,19 @@ class DFG(object):
                 raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}")
             else:
                 raise ValueError(f"Unsupported operators: {error_repr}")
-        graph = nx.DiGraph()
+        ret_graph = nx.DiGraph()
         defs, uses = self.def_use(generated_nodes)
-        graph.add_nodes_from(generated_nodes)
+        ret_graph.add_nodes_from(generated_nodes)
         for onnx_value_name, use_nodes in uses.items():
             if onnx_value_name not in defs:
                 continue
             def_node = defs[onnx_value_name]
             for use_node in use_nodes:
-                graph.add_edge(def_node, use_node)
-        return graph
+                ret_graph.add_edge(def_node, use_node)
+        return ret_graph
 
     # This should be the place where partial evaluation happens
-    def emit_node(self, onnx_node):
+    def emit_node(self, onnx_node: NodeT) -> Optional[List[g.DFGNode]]:
         if onnx_node.op_type == "Conv":
             weight_tensor = self.tensors[onnx_node.input[1]]
             assert isinstance(weight_tensor, WeightTensor)
diff --git a/hpvm/projects/onnx/frontend/hpvm_codegen.py b/hpvm/projects/onnx/frontend/hpvm_codegen.py
index b269a7debffb7c3ae4adefb8e4e0caa329a631db..719e51dc1f102c3b134aca3b90709b5a98df9679 100644
--- a/hpvm/projects/onnx/frontend/hpvm_codegen.py
+++ b/hpvm/projects/onnx/frontend/hpvm_codegen.py
@@ -1,4 +1,6 @@
 from os import PathLike
+from pathlib import Path
+from typing import Dict, List, Tuple, Union
 
 import jinja2
 
@@ -12,22 +14,25 @@ template = template_env.get_template(TEMPLATE_FILE)
 
 
 class HpvmCodeGen:
-    def __init__(self, DFG: DFG, output_dir: PathLike):
-        self.dfg = DFG
-        self.tensors = DFG.tensors
+    def __init__(self, dfg: DFG, output_dir: PathLike):
+        self.dfg = dfg
+        self.tensors = dfg.tensors
         self.var_count = 0
-        self.output_dir = output_dir
+        self.output_dir = Path(output_dir)
         # self.variables is a "onnx name to our name" map
         # Each value is (varname, bool) and the bool indicates
         # "is root node input" or not.
-        self.variables = self._get_root_args(DFG.inputs, DFG.tensors)
+        IdenT = Union[str, int]
+        self.variables: Dict[str, Tuple[IdenT, bool]] = self._get_root_args(
+            dfg.inputs, dfg.tensors
+        )
 
     ################################################
     # Aux functions
     ################################################
 
     @staticmethod
-    def _get_root_args(input_nodes, tensors):
+    def _get_root_args(input_nodes, tensors) -> Dict[str, Tuple[int, bool]]:
         # Input to the graph + all weight tensors
         # Sometimes these 2 kinds can overlap (due to ONNX optim)
         # We'll dedup this array as well.
@@ -40,31 +45,16 @@ class HpvmCodeGen:
         root_args = sorted(list(set(root_args)))
         return {f_name: (index, True) for index, f_name in enumerate(root_args)}
 
-    def _allocate_varname(self):
+    def _allocate_varname(self) -> str:
         varname = f"var_{self.var_count}"
         self.var_count += 1
         return varname
 
-    def get_varname_of(self, onnx_var_name):
-        if onnx_var_name in self.root_args:
-            return True, self.root_args[onnx_var_name]
-        elif onnx_var_name in self.local_vars:
-            return False, self.local_vars[onnx_var_name]
-        else:
-            raise KeyError(onnx_var_name)
-
-    @staticmethod
-    def transform_name(name: str):
-        name = name.replace(".", "_")
-        if name[0].isnumeric():
-            name = "_" + name
-        return name
-
     ################################################
     # CodeGen functions
     ################################################
 
-    def _emit_hpvm_node_edges(self, input_vars):
+    def _emit_hpvm_node_edges(self, input_vars: List[str]) -> List[dict]:
         ret = []
         it = 0
         for onnx_var_name in input_vars:
@@ -81,7 +71,7 @@ class HpvmCodeGen:
             it += 1
         return ret
 
-    def emit_hpvm_node_structures(self):
+    def emit_hpvm_node_structures(self) -> List[dict]:
         node_envs = []
         for node in self.dfg.traverse_order:
             generated_code = node.hpvm_codegen(self.tensors)
@@ -103,26 +93,26 @@ class HpvmCodeGen:
             )
         return node_envs
 
-    def emit_root_io(self):
+    def emit_root_io(self) -> Tuple[List[str], str]:
         input_args = [
-            self.transform_name(name)
+            make_c_identifier(name)
             for name, (_, is_root) in self.variables.items()
             if is_root
         ]
         output_arg = self.variables[self.dfg.output.name][0]
         return input_args, output_arg
 
-    def emit_weights(self):
+    def emit_weights(self) -> List[dict]:
         ret = []
         for name, tensor in self.tensors.items():
             if not isinstance(tensor, WeightTensor):
                 continue
-            name = self.transform_name(name)
+            name = make_c_identifier(name)
             file_path = f"{tensor.get_mapped_name()}_path.bin"
             ret.append({"name": name, "shape": tensor.shape, "filename": file_path})
         return ret
 
-    def compile(self):
+    def compile(self) -> None:
         nodes = self.emit_hpvm_node_structures()
         inputs, output = self.emit_root_io()
         weights = self.emit_weights()
@@ -136,3 +126,10 @@ class HpvmCodeGen:
                     output_dir=self.output_dir,
                 )
             )
+
+
+def make_c_identifier(name: str) -> str:
+    name = name.replace(".", "_")
+    if name[0].isnumeric():
+        name = "_" + name
+    return name
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index 7e3809d0eb6c635709b8727ac5a07738db3373c2..161b6e8efcac0128142e2475a67fdc81340d05e4 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -1,5 +1,6 @@
 from pathlib import Path
-from typing import Iterable, Optional
+from typing import List, Optional
+
 import onnx
 
 
@@ -24,7 +25,7 @@ def check_version(model, new_version):
 
 def compile(
     model,
-    input_size: Iterable[int],
+    input_size: Optional[List[int]],
     output_dir: Path,
     opset_version: Optional[int],
     hpvmc: bool,
@@ -35,7 +36,7 @@ def compile(
 
     if opset_version is not None:
         model = check_version(model, opset_version)
-    graphBuilder = GraphBuilder(model, output_dir)
+    graphBuilder = GraphBuilder(model)
     if hpvmc:
         hpvmCodeGen = HpvmCodeGen(graphBuilder.dfg, output_dir)
         hpvmCodeGen.compile()
@@ -54,7 +55,6 @@ def parse_args():
         "-s",
         "--input-size",
         type=int,
-        required=True,
         nargs="+",
         help="""Size of input tensor to the model.
 Usually 4 dim, including batch size.