Skip to content
Snippets Groups Projects
Commit 7cd2a180 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Simplify input finding

parent 0c844745
No related branches found
No related tags found
No related merge requests found
......@@ -4,11 +4,11 @@ from typing import Dict, List, Tuple, Union
import jinja2
from graph_builder import DFG, NodeT
from graph_builder import DFG
from tensor import Tensor, WeightTensor
TEMPLATE_FILE = "template_hpvm.cpp"
loader = jinja2.FileSystemLoader(searchpath="./")
loader = jinja2.FileSystemLoader(searchpath=Path(__file__).parent)
template_env = jinja2.Environment(loader=loader, trim_blocks=True)
template = template_env.get_template(TEMPLATE_FILE)
......@@ -23,9 +23,10 @@ class HpvmCodeGen:
# Each value is (varname, bool) and the bool indicates
# "is root node input" or not.
IdenT = Union[str, int]
self.variables: Dict[str, Tuple[IdenT, bool]] = get_input_args(
dfg.inputs, dfg.tensors
)
root_args = sorted([t.name for t in self.tensors.values()])
self.variables: Dict[str, Tuple[IdenT, bool]] = {
f_name: (index, True) for index, f_name in enumerate(root_args)
}
################################################
# Aux functions
......@@ -75,7 +76,7 @@ class HpvmCodeGen:
"input_size": len(node.input),
"edges": self._emit_hpvm_node_edges(node.input),
"call_name": func_name,
"call_args": extra_args
"call_args": extra_args,
}
)
return node_envs
......@@ -86,7 +87,7 @@ class HpvmCodeGen:
for name, (_, is_root) in self.variables.items()
if is_root
]
output_arg = self.variables[self.dfg.output.name][0]
output_arg = self.variables[self.dfg.output][0]
return input_args, output_arg
def compile(self) -> None:
......@@ -121,19 +122,3 @@ def emit_weights(tensors: Dict[str, Tensor]) -> List[dict]:
file_path = f"{tensor.new_name}_path.bin"
ret.append({"name": name, "shape": tensor.shape, "filename": file_path})
return ret
def get_input_args(
input_nodes: List[NodeT], tensors: Dict[str, Tensor]
) -> Dict[str, Tuple[int, bool]]:
# Input to the graph + all weight tensors
# Sometimes these 2 kinds can overlap (due to ONNX optim)
# We'll dedup this array as well.
root_args = []
for i in input_nodes:
root_args.append(i.name)
for tensor in tensors.values():
if isinstance(tensor, WeightTensor):
root_args.append(tensor.name)
root_args = sorted(list(set(root_args)))
return {f_name: (index, True) for index, f_name in enumerate(root_args)}
......@@ -4,11 +4,11 @@ from typing import Dict, List, Optional, Union
import jinja2
from codegen_hpvm import emit_weights, get_input_args, make_c_identifier
from codegen_hpvm import emit_weights, make_c_identifier
from graph_builder import DFG
TEMPLATE_FILE = "template_tensor.cpp"
loader = jinja2.FileSystemLoader(searchpath="./")
loader = jinja2.FileSystemLoader(searchpath=Path(__file__).parent)
template_env = jinja2.Environment(loader=loader, trim_blocks=True)
template = template_env.get_template(TEMPLATE_FILE)
......@@ -27,10 +27,8 @@ class TensorCodeGen:
# Each value is (varname, bool) and the bool indicates
# "is root node input" or not.
IdenT = Union[str, int]
# model_inputs (can contain constants) is different from input_arg (which is the "real" input)
model_inputs = get_input_args(dfg.inputs, dfg.tensors)
self.variables: Dict[str, IdenT] = {
k: make_c_identifier(k) for k in model_inputs
k: make_c_identifier(k) for k in self.tensors
}
################################################
......@@ -70,7 +68,7 @@ class TensorCodeGen:
def compile(self):
graph_code = self.emit_graph()
output_arg = self.variables[self.dfg.output.name]
output_arg = self.variables[self.dfg.output]
with open(self.output_dir / "src.cc", "w") as f:
f.write(
template.render(
......
......@@ -65,16 +65,15 @@ class DFG(object):
def __init__(self, graph: GraphT, tensors: Dict[str, Tensor]):
if len(graph.output) > 1:
raise ValueError("Only single-output graph is supported")
self.inputs: List[NodeT] = graph.input
self.output: NodeT = graph.output[0]
self.output: str = graph.output[0].name
self._onnx_defs, self._onnx_uses = self.def_use(graph.node)
self._var_count = 0
self.tensors = tensors
self.graph = self.build_dfg(graph)
self._graph = self.build_dfg(graph)
@property
def traverse_order(self) -> List[g.DFGNode]:
return list(nx.topological_sort(self.graph))
return list(nx.topological_sort(self._graph))
def discover_input_var(self) -> Tuple[str, InputTensor]:
"""Guess which input tensor is the "input" to the ONNX model.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment