Skip to content
Snippets Groups Projects
Commit 31663f3f authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Fixed gemm transpose issue

parent ee8ea8af
No related branches found
No related tags found
No related merge requests found
from collections import defaultdict
from onnx_attr import node_attr_to_dict
from os import PathLike
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple
......@@ -213,8 +214,16 @@ class DFG(object):
self._allocate_insert_var(conv_node, bias_node)
return [conv_node, bias_node]
elif onnx_node.op_type in ("MatMul", "Gemm"):
weight_tensor = self.tensors[onnx_node.input[1]]
assert isinstance(weight_tensor, WeightTensor)
if onnx_node.op_type == "Gemm":
# Some tensors may need transposing
attrs = node_attr_to_dict(onnx_node)
# We cannot transpose input tensor (need a transpose op)
assert not attrs.get('transA', False)
# But we can transpose weight tensor before emitting it
if attrs.get('transB', False):
weight_tensor = self.tensors[onnx_node.input[1]]
assert isinstance(weight_tensor, WeightTensor)
weight_tensor.transpose_()
if len(onnx_node.input) == 2:
return [g.MatMulNode(onnx_node)]
else:
......
from typing import Tuple
import numpy as np
from onnx import AttributeProto, NodeProto, TensorProto
def throw_ctor(ty):
def _throw_ctor(x):
raise ValueError(f"Cannot construct type {ty} from value {x}")
return _throw_ctor
def composite_ctor(singular_ctor):
def _composite_ctor(xs):
return [singular_ctor(x) for x in xs]
return _composite_ctor
def tensor_ctor(x: TensorProto) -> np.ndarray:
import numpy as np
tensor_typenames_to_numpy_ty = {
"BOOL": np.bool,
"UINT8": np.uint8,
"INT8": np.int8,
"UINT16": np.uint16,
"INT16": np.int16,
"INT32": np.int32,
"UINT32": np.uint32,
"INT64": np.int64,
"UINT64": np.uint64,
"STRING": np.str,
"FLOAT16": np.float16,
"FLOAT": np.float32,
"DOUBLE": np.float64,
# 'BFLOAT16' -- unsupported
"COMPLEX64": np.complex64,
"COMPLEX128": np.complex128,
}
get_tensor_typename = TensorProto.DataType.Name
tensor_typename = get_tensor_typename(x.data_type)
if tensor_typename not in tensor_typenames_to_numpy_ty:
raise ValueError(f"Tensor with type {tensor_typename} cannot be processed")
numpy_dtype = tensor_typenames_to_numpy_ty[tensor_typename]
numpy_arr = np.frombuffer(x.raw_data, dtype=numpy_dtype).reshape(x.dims).copy()
return numpy_arr
def parse_node_attr(onnx_attr: AttributeProto) -> Tuple[str, object]:
from collections import namedtuple
AttrMeta = namedtuple("AttrMeta", ["ctor", "data_name"])
attr_typenames_to_meta = {
"FLOAT": AttrMeta(float, "f"),
"INT": AttrMeta(int, "i"),
"STRING": AttrMeta(str, "s"),
"TENSOR": AttrMeta(tensor_ctor, "t"),
"GRAPH": AttrMeta(throw_ctor("GRAPH"), "g"),
"SPARSE_TENSOR": AttrMeta(throw_ctor("SPARSE_TENSOR"), "sparse_tensor"),
"FLOATS": AttrMeta(composite_ctor(float), "floats"),
"INTS": AttrMeta(composite_ctor(int), "ints"),
"STRINGS": AttrMeta(composite_ctor(str), "strings"),
"TENSORS": AttrMeta(composite_ctor(tensor_ctor), "tensors"),
"GRAPHS": AttrMeta(throw_ctor("GRAPHS"), "graphs"),
"SPARSE_TENSORS": AttrMeta(throw_ctor("SPARSE_TENSORS"), "sparse_tensors"),
}
get_attr_typename = AttributeProto.AttributeType.Name
typename = get_attr_typename(onnx_attr.type)
assert (
typename in attr_typenames_to_meta
), f"ONNX attribute contains non-ONNX type {typename}"
attr_meta = attr_typenames_to_meta[typename]
data = getattr(onnx_attr, attr_meta.data_name)
parsed_data = attr_meta.ctor(data)
return parsed_data
def node_attr_to_dict(onnx_node: NodeProto):
return {attr.name: parse_node_attr(attr) for attr in onnx_node.attribute}
......@@ -46,3 +46,9 @@ class WeightTensor(Tensor):
def dump_weight(self, file_name: PathLike):
self.input_data.tofile(file_name)
def transpose_(self):
if len(self.input_data.shape) != 2:
raise ValueError("Can only transpose 2D array")
self.input_data = self.input_data.T
self.shape[3], self.shape[2] = self.shape[2:]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment