Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
H
hpvm-release
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
llvm
hpvm-release
Commits
27c49a85
Commit
27c49a85
authored
4 years ago
by
Yifan Zhao
Browse files
Options
Downloads
Patches
Plain Diff
Use networkx for onnx graph
parent
5f46b4b1
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
hpvm/projects/onnx/frontend/graph_builder.py
+132
-98
132 additions, 98 deletions
hpvm/projects/onnx/frontend/graph_builder.py
with
132 additions
and
98 deletions
hpvm/projects/onnx/frontend/graph_builder.py
+
132
−
98
View file @
27c49a85
...
...
@@ -13,6 +13,7 @@ from tensor import InputTensor, Tensor, WeightTensor
ModelT
=
onnx
.
ModelProto
GraphT
=
onnx
.
GraphProto
NodeT
=
onnx
.
NodeProto
NodeT
.
__hash__
=
lambda
self
:
id
(
self
)
class
GraphBuilder
:
...
...
@@ -43,7 +44,9 @@ class GraphBuilder:
# parse weight
weight_cnt
=
0
for
weight_tensor
in
onnx_graph
.
initializer
:
tensors
[
weight_tensor
.
name
]
=
WeightTensor
(
weight_tensor
,
f
"
weight_
{
weight_cnt
}
"
)
tensors
[
weight_tensor
.
name
]
=
WeightTensor
(
weight_tensor
,
f
"
weight_
{
weight_cnt
}
"
)
weight_cnt
+=
1
# parse input
input_cnt
=
0
...
...
@@ -67,12 +70,14 @@ class DFG(object):
if
len
(
graph
.
output
)
>
1
:
raise
ValueError
(
"
Only single-output graph is supported
"
)
self
.
output
:
str
=
graph
.
output
[
0
].
name
self
.
_onnx_defs
,
self
.
_onnx_uses
=
self
.
def_use
(
graph
.
node
)
self
.
_var_count
=
0
self
.
tensors
=
tensors
self
.
_graph
=
self
.
_build_dfg
(
graph
)
onnx_graph
=
self
.
_build_onnx_dfg
(
graph
)
self
.
_graph
=
self
.
_build_dfg
(
onnx_graph
)
self
.
_dce
()
# Remove unused values
################ Interfaces:
@property
def
traverse_order
(
self
)
->
List
[
g
.
DFGNode
]:
return
list
(
nx
.
topological_sort
(
self
.
_graph
))
...
...
@@ -89,115 +94,74 @@ class DFG(object):
assert
len
(
inputs
)
==
1
return
inputs
[
0
]
@staticmethod
def
def_use
(
nodes
:
Iterable
)
->
Tuple
[
dict
,
dict
]:
"""
Computes def/use relation from a list of node.
################ Internal methods (high-level):
This method is duck-typed and operates on any node defining .input and .output.
"""
defs
,
uses
=
{},
defaultdict
(
list
)
for
n
in
nodes
:
for
input_
in
n
.
input
:
uses
[
input_
].
append
(
n
)
for
output
in
n
.
output
:
defs
[
output
]
=
n
return
defs
,
uses
def
_build_onnx_dfg
(
self
,
graph
:
GraphT
)
->
nx
.
DiGraph
:
"""
Creates a DiGraph (by use-def relation) of onnx nodes from onnx GraphProto.
DiGraph is easier to use as a graph compared to GraphProto where use-def is implicit.
"""
ret_graph
=
nx
.
DiGraph
()
onnx_defs
,
onnx_uses
=
self
.
_def_use
(
graph
.
node
)
ret_graph
.
add_nodes_from
(
graph
.
node
)
for
onnx_value_name
,
use_nodes
in
onnx_uses
.
items
():
if
onnx_value_name
not
in
onnx_defs
:
continue
def_node
=
onnx_defs
[
onnx_value_name
]
for
use_node
in
use_nodes
:
ret_graph
.
add_edge
(
def_node
,
use_node
)
return
ret_graph
def
_build_dfg
(
self
,
onnx_graph
:
nx
.
DiGraph
)
->
nx
.
DiGraph
:
onnx_graph
=
self
.
_detect_flatten
(
onnx_graph
)
# For each onnx node, generate our nodes
ret_graph
=
onnx_graph
.
copy
()
error_nodes
=
[]
for
onnx_node
in
onnx_graph
.
nodes
:
if
isinstance
(
onnx_node
,
g
.
DFGNode
):
continue
our_nodes
=
self
.
_emit_node
(
onnx_node
)
if
our_nodes
is
None
:
error_nodes
.
append
(
onnx_node
)
else
:
replace_node_with_chain_
(
ret_graph
,
onnx_node
,
our_nodes
)
if
error_nodes
:
error_repr
=
[
f
"
{
n
.
name
}
(
{
n
.
op_type
}
)
"
for
n
in
error_nodes
]
if
len
(
error_nodes
)
>
10
:
# Magic number
raise
ValueError
(
f
"
Unsupported operators (first 10):
{
error_repr
[
:
10
]
}
"
)
else
:
raise
ValueError
(
f
"
Unsupported operators:
{
error_repr
}
"
)
return
ret_graph
def
_dce
(
self
):
_
,
uses
=
self
.
def_use
(
self
.
_graph
.
nodes
)
_
,
uses
=
self
.
_
def_use
(
self
.
_graph
.
nodes
)
used_values
=
set
(
uses
.
keys
())
unused_values
=
set
(
self
.
tensors
.
keys
())
-
used_values
for
k
in
unused_values
:
self
.
tensors
.
pop
(
k
)
def
_split_node_args
(
self
,
node1
:
g
.
DFGNode
,
node2
:
g
.
DFGNode
,
input_pos
:
int
=
0
,
pop_pos
:
int
=
-
1
)
->
None
:
varname
=
f
"
conv_
{
self
.
_var_count
}
"
node1
.
input
.
pop
(
pop_pos
)
node1
.
output
=
[
varname
]
node2
.
input
[
input_pos
]
=
varname
self
.
_var_count
+=
1
def
_detect_flatten
(
self
,
graph
:
GraphT
)
->
Tuple
[
Dict
[
str
,
NodeT
],
List
[
g
.
DFGNode
]]:
# Look for a shape-gather-unsqueeze-concat chain
nodes
=
graph
.
node
included_nodes
=
{}
# Name to node
generated_nodes
=
[]
def
get_next_in_chain
(
type_
:
str
,
node
:
Optional
[
NodeT
])
->
Optional
[
NodeT
]:
"""
Get a unique user node of the unique output of Node `node`,
and return it if it has Type `type_`.
Also put this node into `included_nodes`.
"""
if
node
is
None
:
return
None
# propagates None
if
len
(
node
.
output
)
!=
1
:
return
None
# Unique output
users
=
self
.
_onnx_uses
.
get
(
node
.
output
[
0
],
[])
if
len
(
users
)
!=
1
:
return
None
# Unique user of the output
(
user
,)
=
users
if
user
.
op_type
!=
type_
:
return
None
# Correct type
if
user
.
name
in
included_nodes
:
# Part of this chain intersects another chain, so we give up
# TODO: in fact we should remove BOTH chain in this case
return
None
return
user
def
_detect_flatten
(
self
,
graph
:
nx
.
DiGraph
)
->
nx
.
DiGraph
:
"""
Look for a shape-gather-unsqueeze-concat-reshape chain and replace that with flatten.
"""
def
get_def_at_pos
(
node
,
pos
:
int
):
return
self
.
_onnx_defs
[
node
.
input
[
pos
]]
def
add_nodes
(
ns
):
for
n
in
ns
:
included_nodes
[
n
.
name
]
=
n
from_
,
to
=
list
(
graph
.
in_edges
(
node
))[
pos
]
return
from_
for
n
in
nodes
:
if
n
.
op_type
!=
"
Shape
"
:
for
n
ode
in
list
(
graph
.
nodes
)
:
if
n
ode
.
op_type
!=
"
Shape
"
:
continue
ng
=
get_next_in_chain
(
"
Gather
"
,
n
)
ng
=
self
.
get_next_in_chain
(
graph
,
"
Gather
"
,
n
ode
)
# Find the second input argument to Gather (will be a Constant node)
# and take that away as well.
nct
=
get_def_at_pos
(
ng
,
1
)
nu
=
get_next_in_chain
(
"
Unsqueeze
"
,
ng
)
nc
=
get_next_in_chain
(
"
Concat
"
,
nu
)
nr
=
get_next_in_chain
(
"
Reshape
"
,
nc
)
if
nr
is
not
None
:
nodes
=
[
n
,
ng
,
nct
,
nu
,
nc
,
nr
]
add_nodes
(
nodes
)
generated_nodes
.
append
(
g
.
FlattenNode
.
from_onnx_idiom
(
nodes
))
return
included_nodes
,
generated_nodes
def
_build_dfg
(
self
,
graph
:
GraphT
)
->
nx
.
DiGraph
:
error_nodes
,
generated_nodes
=
[],
[]
used_onnx_nodes
,
flatten_nodes
=
self
.
_detect_flatten
(
graph
)
generated_nodes
.
extend
(
flatten_nodes
)
for
onnx_node
in
graph
.
node
:
if
onnx_node
.
name
in
used_onnx_nodes
:
continue
our_node
=
self
.
_emit_node
(
onnx_node
)
if
our_node
is
None
:
error_nodes
.
append
(
onnx_node
)
else
:
generated_nodes
.
extend
(
our_node
)
if
error_nodes
:
error_repr
=
[
f
"
{
n
.
name
}
(
{
n
.
op_type
}
)
"
for
n
in
error_nodes
]
if
len
(
error_nodes
)
>
10
:
# Magic number
raise
ValueError
(
f
"
Unsupported operators (first 10):
{
error_repr
[
:
10
]
}
"
)
else
:
raise
ValueError
(
f
"
Unsupported operators:
{
error_repr
}
"
)
ret_graph
=
nx
.
DiGraph
()
defs
,
uses
=
self
.
def_use
(
generated_nodes
)
ret_graph
.
add_nodes_from
(
generated_nodes
)
for
onnx_value_name
,
use_nodes
in
uses
.
items
():
if
onnx_value_name
not
in
defs
:
nu
=
self
.
get_next_in_chain
(
graph
,
"
Unsqueeze
"
,
ng
)
nc
=
self
.
get_next_in_chain
(
graph
,
"
Concat
"
,
nu
)
nr
=
self
.
get_next_in_chain
(
graph
,
"
Reshape
"
,
nc
)
if
nr
is
None
:
continue
def_
node
=
defs
[
onnx_value_name
]
for
use_node
in
use_
nodes
:
ret_graph
.
add_edge
(
def_
node
,
use
_node
)
return
ret_
graph
node
s
=
[
node
,
ng
,
nct
,
nu
,
nc
,
nr
]
gen_node
=
g
.
FlattenNode
.
from_onnx_idiom
(
nodes
)
graph
=
replace_graph_with_node_
(
graph
,
node
s
,
gen
_node
)
return
graph
# This should be the place where partial evaluation happens
def
_emit_node
(
self
,
onnx_node
:
NodeT
)
->
Optional
[
List
[
g
.
DFGNode
]]:
...
...
@@ -219,9 +183,9 @@ class DFG(object):
# Some tensors may need transposing
attrs
=
node_attr_to_dict
(
onnx_node
)
# We cannot transpose input tensor (need a transpose op)
assert
not
attrs
.
get
(
'
transA
'
,
False
)
assert
not
attrs
.
get
(
"
transA
"
,
False
)
# But we can transpose weight tensor before emitting it
if
attrs
.
get
(
'
transB
'
,
False
):
if
attrs
.
get
(
"
transB
"
,
False
):
weight_tensor
=
self
.
tensors
[
onnx_node
.
input
[
1
]]
assert
isinstance
(
weight_tensor
,
WeightTensor
)
weight_tensor
.
transpose_
()
...
...
@@ -248,3 +212,73 @@ class DFG(object):
if
onnx_node
.
op_type
in
one_to_one_nodes
:
return
[
one_to_one_nodes
[
onnx_node
.
op_type
](
onnx_node
)]
return
None
################ Internal methods (utils):
@staticmethod
def
get_next_in_chain
(
graph
:
nx
.
DiGraph
,
type_
:
str
,
node
:
Optional
[
NodeT
]
)
->
Optional
[
NodeT
]:
"""
Get a unique user node of the unique output of Node `node`,
and return it if it has Type `type_`.
"""
if
node
is
None
or
len
(
node
.
output
)
!=
1
:
return
None
# Propagates None; Unique output
users
=
list
(
graph
.
neighbors
(
node
))
if
len
(
users
)
!=
1
or
users
[
0
].
op_type
!=
type_
:
return
None
# Unique user of the output; Correct type
return
users
[
0
]
def
_split_node_args
(
self
,
node1
:
g
.
DFGNode
,
node2
:
g
.
DFGNode
,
input_pos
:
int
=
0
,
pop_pos
:
int
=
-
1
)
->
None
:
varname
=
f
"
conv_
{
self
.
_var_count
}
"
node1
.
input
.
pop
(
pop_pos
)
node1
.
output
=
[
varname
]
node2
.
input
[
input_pos
]
=
varname
self
.
_var_count
+=
1
@staticmethod
def
_def_use
(
nodes
:
Iterable
)
->
Tuple
[
dict
,
dict
]:
"""
Computes def/use relation from a list of node.
This method is duck-typed and operates on any node defining .input and .output.
"""
defs
,
uses
=
{},
defaultdict
(
list
)
for
n
in
nodes
:
for
input_
in
n
.
input
:
uses
[
input_
].
append
(
n
)
for
output
in
n
.
output
:
defs
[
output
]
=
n
return
defs
,
uses
def
replace_graph_with_node_
(
graph
:
nx
.
DiGraph
,
subgraph
:
Iterable
,
node
)
->
nx
.
DiGraph
:
left_neighbors
,
right_neighbors
=
set
(),
set
()
for
n
in
subgraph
:
left_neighbors
.
update
(
from_
for
from_
,
to
in
graph
.
in_edges
(
n
))
right_neighbors
.
update
(
to
for
from_
,
to
in
graph
.
out_edges
(
n
))
graph
.
remove_node
(
n
)
for
n
in
left_neighbors
:
if
n
in
graph
:
graph
.
add_edge
(
n
,
node
)
for
n
in
right_neighbors
:
if
n
in
graph
:
graph
.
add_edge
(
node
,
n
)
return
graph
def
replace_node_with_chain_
(
graph
:
nx
.
DiGraph
,
node
,
chain
:
Iterable
)
->
nx
.
DiGraph
:
chain
=
list
(
chain
)
if
not
chain
:
graph
.
remove_node
(
node
)
return
graph
for
n1
,
n2
in
zip
(
chain
,
chain
[
1
:]):
graph
.
add_edge
(
n1
,
n2
)
# Add the chain first
for
from_
,
_
in
graph
.
in_edges
(
node
):
graph
.
add_edge
(
from_
,
chain
[
0
])
for
_
,
to
in
graph
.
out_edges
(
node
):
graph
.
add_edge
(
chain
[
-
1
],
to
)
graph
.
remove_node
(
node
)
return
graph
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment