Skip to content
Snippets Groups Projects
Commit 6fa47a82 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Free up some codegen function from member functions

parent 6b7c58d2
No related branches found
No related tags found
No related merge requests found
......@@ -4,8 +4,8 @@ from typing import Dict, List, Tuple, Union
import jinja2
from graph_builder import DFG
from tensor import WeightTensor
from graph_builder import DFG, NodeT
from tensor import Tensor, WeightTensor
TEMPLATE_FILE = "template_hpvm.cpp"
loader = jinja2.FileSystemLoader(searchpath="./")
......@@ -23,7 +23,7 @@ class HpvmCodeGen:
# Each value is (varname, bool) and the bool indicates
# "is root node input" or not.
IdenT = Union[str, int]
self.variables: Dict[str, Tuple[IdenT, bool]] = self._get_root_args(
self.variables: Dict[str, Tuple[IdenT, bool]] = get_input_args(
dfg.inputs, dfg.tensors
)
......@@ -31,20 +31,6 @@ class HpvmCodeGen:
# Aux functions
################################################
@staticmethod
def _get_root_args(input_nodes, tensors) -> Dict[str, Tuple[int, bool]]:
    """Map each root-level argument name to ``(argument_index, True)``.

    Root arguments are the graph inputs plus every ``WeightTensor``.
    The two groups can overlap (ONNX optimization may fold an input
    into a weight), so names are deduplicated; sorting keeps the index
    assignment deterministic across runs. The ``True`` flag marks the
    entry as a root-node input.
    """
    arg_names = {node.name for node in input_nodes}
    arg_names.update(
        tensor.name
        for tensor in tensors.values()
        if isinstance(tensor, WeightTensor)
    )
    # sorted(set) both dedups and fixes a stable order in one pass
    # (the original sorted(list(set(...))) built a redundant list).
    return {name: (idx, True) for idx, name in enumerate(sorted(arg_names))}
def _allocate_varname(self) -> str:
varname = f"var_{self.var_count}"
self.var_count += 1
......@@ -103,20 +89,10 @@ class HpvmCodeGen:
output_arg = self.variables[self.dfg.output.name][0]
return input_args, output_arg
def emit_weights(self) -> List[dict]:
    """Describe every WeightTensor held by this generator.

    Returns one dict per weight with its C-safe "name", its "shape",
    and the "filename" of the binary blob holding its values.
    """
    weight_entries = []
    for tensor_name, tensor in self.tensors.items():
        if isinstance(tensor, WeightTensor):
            weight_entries.append(
                {
                    "name": make_c_identifier(tensor_name),
                    "shape": tensor.shape,
                    "filename": f"{tensor.get_mapped_name()}_path.bin",
                }
            )
    return weight_entries
def compile(self) -> None:
nodes = self.emit_hpvm_node_structures()
inputs, output = self.emit_root_io()
weights = self.emit_weights()
weights = emit_weights(self.tensors)
with open(self.output_dir / "hpvm_src.cc", "w") as f:
f.write(
template.render(
......@@ -134,3 +110,30 @@ def make_c_identifier(name: str) -> str:
if name[0].isnumeric():
name = "_" + name
return name
def emit_weights(tensors: Dict[str, Tensor]) -> List[dict]:
    """Describe every WeightTensor in *tensors*.

    Returns one dict per weight carrying its C-safe "name", its
    "shape", and the "filename" of the binary file with its values.
    """
    entries = []
    for raw_name, tensor in tensors.items():
        if isinstance(tensor, WeightTensor):
            entries.append(
                {
                    "name": make_c_identifier(raw_name),
                    "shape": tensor.shape,
                    "filename": f"{tensor.get_mapped_name()}_path.bin",
                }
            )
    return entries
def get_input_args(
    input_nodes: "List[NodeT]", tensors: "Dict[str, Tensor]"
) -> Dict[str, Tuple[int, bool]]:
    """Map each root-level argument name to ``(argument_index, True)``.

    Root arguments are the graph inputs plus every ``WeightTensor``.
    The two groups can overlap (ONNX optimization may fold an input
    into a weight), so names are deduplicated; sorting keeps the index
    assignment deterministic across runs. The ``True`` flag marks the
    entry as a root-node input.
    """
    arg_names = {node.name for node in input_nodes}
    arg_names.update(
        tensor.name
        for tensor in tensors.values()
        if isinstance(tensor, WeightTensor)
    )
    # sorted(set) both dedups and fixes a stable order in one pass
    # (the original sorted(list(set(...))) built a redundant list).
    return {name: (idx, True) for idx, name in enumerate(sorted(arg_names))}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment