# graph_builder.py 9.47 KiB
import sys
from onnx import numpy_helper
from graph_ir import Node
from common import InputTensor, WeightTensor
# ONNX operator types accepted by GraphBuilder.
#   None        -> the op is accepted unconditionally.
#   list of int -> the op is accepted only when the length of its
#                  ``kernel_shape`` attribute is one of the listed values
#                  (see GraphBuilder._support_check).
# NOTE(review): "DepthwiseConv" and "Activation" are not standard ONNX
# op_type names -- confirm they are intentional.
support_onnx_ops = {
    "DepthwiseConv": [2],
    "Conv": [2],  # only 2d supported here
    "MatMul": None,
    "MaxPool": [2],  # only 2d supported here
    "Activation": None,
    "BatchNormalization": None,
    "Flatten": None,
    "Add": None,
    "Relu": None,
    "Softmax": None,
    "Identity": None,
    "Pad": None,
    "AveragePool": None,
    "Tanh": None,
}
class GraphBuilder(object):
    """Turn an ONNX model into a ``DFG`` ready for compilation.

    The constructor validates the model (via ``onnx.checker`` when it is
    available) and rejects it immediately if any node's ``op_type`` is not a
    key of ``support_onnx_ops``. ``build_graph`` then registers every tensor
    and returns the resulting ``DFG``.
    """

    def __init__(self, model, shape, dtype, weight_dir):
        """
        Args:
            model: ONNX model; its ``graph`` is what gets compiled.
            shape: mapping of graph-input name -> shape. When falsy it is
                derived from the graph inputs by ``_build_shape``.
            dtype: stored as ``self.dtype`` (not read elsewhere in this class).
            weight_dir: stored as ``self.weight_dir`` (not read elsewhere in
                this class; presumably used when weights are dumped).

        Raises:
            ValueError: if the model uses operator types that are not supported.
            ImportError: if ``onnx.checker`` cannot be imported.
        """
        self._check_model(model)
        self._check_ops(model)
        self.model = model
        self.dtype = dtype
        self.graph = model.graph
        self.weight_dir = weight_dir
        self.shape = shape if shape else self._build_shape()
        # tensor name -> WeightTensor / InputTensor; populated by build_graph().
        self.tensors = dict()

    ################################################
    # Aux functions for graph building
    ################################################

    def _check_model(self, onnx_model):
        """Run ``onnx.checker.check_model`` on the model when it is available.

        Validation failures are downgraded to warnings so conversion still
        proceeds.

        Raises:
            ImportError: if ``onnx.checker`` cannot be imported.
        """
        try:
            from onnx import checker, onnx_cpp2py_export
            if hasattr(checker, 'check_model'):
                # try use onnx's own model checker before converting any model
                try:
                    checker.check_model(onnx_model)
                    print("onnx model is checked valid!")
                except onnx_cpp2py_export.checker.ValidationError as e:
                    import warnings
                    warnings.warn(str(e))
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx.checker which is required {}".format(e))

    def _check_ops(self, model):
        """Reject models that use operator types missing from support_onnx_ops.

        The unsupported types are printed as ``(op_type, count)`` pairs, most
        frequent first. The same report is embedded in the exception message.

        Raises:
            ValueError: if at least one unsupported operator type is found.
        """
        from collections import Counter

        unsupported = Counter(
            node.op_type for node in model.graph.node
            if node.op_type not in support_onnx_ops)
        if unsupported:
            # most_common() orders by count descending; ties keep first-seen order.
            report = unsupported.most_common()
            print(report)
            raise ValueError(
                "Above operator(s) not currently supported! Compilation "
                "Aborted. Unsupported: {}".format(report))

    def _build_shape(self):
        """Map each graph input that declares a shape to that shape.

        Values are the ``shape`` field of the input's tensor type, taken as-is
        (not converted to a tuple). Inputs without a shape are omitted.
        """
        shape = {}
        for value_info in self.graph.input:
            tensor_type = value_info.type.tensor_type
            if tensor_type.HasField("shape"):
                shape[value_info.name] = tensor_type.shape
        return shape

    def _parse_array(self, tensor_proto):
        """Convert a TensorProto to a numpy array reshaped to its ``dims``.

        Raises:
            ImportError: if ``onnx.numpy_helper`` cannot be imported.
        """
        try:
            from onnx.numpy_helper import to_array
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx which is required {}".format(e))
        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
        return np_array

    def _parse_value_proto(self, value_proto):
        """Parse ValueProto or raw str: return its ``name``, or the value itself."""
        return getattr(value_proto, "name", value_proto)

    def _parse_dtype(self, value_proto, dtype):
        """Return the numpy dtype name of ``value_proto``'s tensor element type.

        Falls back to ``dtype`` when ``value_proto`` carries no tensor type
        (e.g. a raw string).
        """
        # Resolve the element type first so raw strings fall back to ``dtype``
        # without importing onnx at all.
        try:
            elem_type = value_proto.type.tensor_type.elem_type
        except AttributeError:
            return dtype
        try:
            from onnx.helper import tensor_dtype_to_np_dtype
        except ImportError:
            # onnx < 1.13 has no tensor_dtype_to_np_dtype; use the older
            # (since deprecated) mapping table instead.
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
            return TENSOR_TYPE_TO_NP_TYPE[elem_type].name
        return tensor_dtype_to_np_dtype(elem_type).name

    def _support_check(self, node):
        """Return True if ``node`` can be compiled by this builder.

        The ``op_type`` must be a key of ``support_onnx_ops``. When its entry is
        a list, the node's ``kernel_shape`` attribute must have one of the
        listed lengths. A node without ``kernel_shape`` is rejected.
        """
        op_name = node.op_type
        if op_name not in support_onnx_ops:
            return False
        allowed_ranks = support_onnx_ops[op_name]
        if allowed_ranks is None:
            return True
        for attr in node.attribute:
            # partially evaluate the kernel shape
            if attr.name == 'kernel_shape':
                # TODO: not assume all kernel shape is in INTS
                return len(attr.ints) in allowed_ranks
        # TODO: ONNX allows Conv to omit kernel_shape (inferred from W);
        # such nodes are currently rejected.
        return False

    def _dump_weight(self, weight_tensor):
        """Report a weight tensor (currently only prints its name)."""
        print("Dump weight: {0}".format(weight_tensor.name))

    ################################################
    # Top level Graph Building functions
    # return the compilation-ready graph
    ################################################

    def build_graph(self):
        """Register every tensor of the graph and return a ``DFG`` over it.

        Initializers become WeightTensors mapped to ``weight_0``, ``weight_1``,
        and so on. Graph inputs that are not initializers become InputTensors.
        Node outputs are registered as they are produced.

        Raises:
            ValueError: if a node fails ``_support_check`` or consumes a tensor
                that was not registered before it.
        """
        # parse weights
        for weight_idx, weight_tensor in enumerate(self.graph.initializer):
            self.tensors[weight_tensor.name] = WeightTensor(weight_tensor)
            self.tensors[weight_tensor.name].set_mapped_name(
                "weight_" + str(weight_idx))

        # parse graph inputs; initializers may also be listed here, and they
        # stay WeightTensors.
        for graph_input in self.graph.input:
            if graph_input.name not in self.tensors:
                self.tensors[graph_input.name] = InputTensor(graph_input.name)
                # FIXME: This input name is hardcoded
                self.tensors[graph_input.name].set_mapped_name("input")

        # parse intermediate tensors
        for node in self.graph.node:
            if not self._support_check(node):
                raise ValueError(
                    "Operator not currently supported: `{0}`!".format(node.op_type))
            for name in node.input:
                # ONNX uses an empty name for an omitted optional input.
                if name and name not in self.tensors:
                    raise ValueError(
                        "Compilation Interrupted for missing input!`{0}`.".format(name))
            for name in node.output:
                # Likewise, an empty name is an omitted optional output.
                if name and name not in self.tensors:
                    # NOTE(review): intermediates are modeled as InputTensor.
                    self.tensors[name] = InputTensor(name)

        # Dump weights
        for tensor in self.tensors.values():
            if isinstance(tensor, WeightTensor):
                self._dump_weight(tensor)
        return DFG(self.graph, self.tensors)
class DFG(object):
    """Dataflow graph of compiled layers, with a reverse-postorder traversal.

    NOTE(review): ``add_to_graph`` reads Keras-style layer attributes
    (``layer.input``, ``.op.name``, ``layer.name``). ``hasSingleInput`` and
    ``hasMultipleInputs`` read ``singleInputLayers`` / ``multiInputLayers``,
    which are not defined in this file. ``DFGNode`` and ``Conv2DNode`` are
    not imported here either. Confirm where these are expected to come from.
    """

    # Kept at class level for backward compatibility; __init__ also sets it
    # per instance.
    root_set = False

    # Op types emitNode accepts but does not lower yet; it returns None for
    # them. "SoftMax" is the legacy misspelling, kept so old callers still
    # get None instead of an error.
    _UNLOWERED_OPS = frozenset([
        "Tanh", "MaxPool", "Flatten", "MatMul", "Add",
        "Softmax", "SoftMax", "Identity",
    ])

    def __init__(self, graph, tensors):
        self.graph = graph
        self.tensors = tensors
        # layer name -> DFG node; filled by add_to_graph, read by add_dfg_edge.
        self.node_map = {}
        self.root_set = False
        self.root_node = None
        # Set by traverseNode to the last visited node with no outputs.
        self.last_node = None

    def hasSingleInput(self, layer):
        """True if the layer's class name is in ``self.singleInputLayers``."""
        layer_name = layer.__class__.__name__
        return layer_name in self.singleInputLayers

    def hasMultipleInputs(self, layer):
        """True if the layer's class name is in ``self.multiInputLayers``."""
        layer_name = layer.__class__.__name__
        return layer_name in self.multiInputLayers

    def add_dfg_edge(self, inbound_node_name, dfg_node):
        """Link ``dfg_node`` after the node registered as ``inbound_node_name``.

        The name is cut at the first ':' and then at the first '/'. For
        example, "conv1/BiasAdd:0" is looked up as "conv1". Unknown names are
        reported and ignored.
        """
        inbound_node_name = inbound_node_name.split(":")[0]
        inbound_node_name = inbound_node_name.split("/")[0]
        if inbound_node_name in self.node_map:
            inbound_node = self.node_map[inbound_node_name]
            print(inbound_node_name, " found!")
            inbound_node.add_output(dfg_node)
            dfg_node.add_input(inbound_node)
        else:
            print("--inbound node NOT FOUND!")

    def add_to_graph(self, layer):
        """Wrap ``layer`` in a DFGNode, wire its inbound edges, and register it.

        The first layer added becomes the traversal root.
        """
        dfg_node = DFGNode(layer)
        if not self.root_set:
            self.root_node = dfg_node
            self.root_set = True  # DFG root node is now set

        if self.hasMultipleInputs(layer):
            for inbound in layer.input:
                print(type(inbound))
                print(inbound.op.name)
                self.add_dfg_edge(inbound.op.name, dfg_node)
        else:
            print(layer.input.name)
            self.add_dfg_edge(layer.input.name, dfg_node)

        # Adding DFG node to name mapping
        self.node_map[layer.name] = dfg_node

    def predVisited(self, cur_node, visited_nodes):
        """Return True if every predecessor of ``cur_node`` has been visited."""
        return all(input_node.layer_name in visited_nodes
                   for input_node in cur_node.inputs)

    def traverseNode(self, cur_node, visited_nodes):
        """Visit ``cur_node`` once all its predecessors are visited, then recurse.

        A node whose predecessors are not all visited yet is skipped here. It
        is reached again through its remaining predecessors.
        """
        # Skip visited nodes
        if cur_node.layer_name in visited_nodes:
            return
        if self.predVisited(cur_node, visited_nodes):
            print(cur_node.layer_type)
            print(cur_node.layer_name)
            visited_nodes[cur_node.layer_name] = True
            # Invoking traversal on outbound nodes
            for output_node in cur_node.outputs:
                self.traverseNode(output_node, visited_nodes)
            # NOTE: Assuming that no outbound edges implies the last node in
            # the graph
            if len(cur_node.outputs) == 0:
                self.last_node = cur_node

    def buildDFG(self):
        """Traverse and print the DFG in reverse postorder, starting at the root."""
        print("\n\n ****** Traversing and Printing DFG ******* \n\n")
        visited_nodes = {}
        # Starting traversal at the DFG root node
        self.traverseNode(self.root_node, visited_nodes)

    def emitNode(self, layer):
        """Return the IR node for ``layer``; this is where partial evaluation happens.

        Only "Conv" produces a node so far. Op types in ``_UNLOWERED_OPS`` are
        accepted but yield None.

        Raises:
            ValueError: for any other op type.
        """
        op_type = layer.op_type
        if op_type == "Conv":
            return Conv2DNode()
        if op_type in self._UNLOWERED_OPS:
            # TODO: lower these operators; nothing is emitted for them yet.
            return None
        raise ValueError("Unsupported operator type!")