Skip to content
Snippets Groups Projects
Commit 384f4740 authored by Neta Zmora's avatar Neta Zmora
Browse files

SummaryGraph: small changes for ONNX in pytorch 0.4

parent b9bf4282
No related branches found
No related tags found
No related merge requests found
...@@ -60,7 +60,7 @@ class SummaryGraph(object): ...@@ -60,7 +60,7 @@ class SummaryGraph(object):
""" """
def __init__(self, model, dummy_input): def __init__(self, model, dummy_input):
with torch.onnx.set_training(model, False): with torch.onnx.set_training(model, False):
trace, _ = jit.trace(model, dummy_input) trace, _ = jit.get_trace_graph(model, dummy_input)
# Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
# composing a GEMM operation; etc. # composing a GEMM operation; etc.
...@@ -79,7 +79,7 @@ class SummaryGraph(object): ...@@ -79,7 +79,7 @@ class SummaryGraph(object):
op = {} op = {}
op['name'] = node.scopeName() op['name'] = node.scopeName()
op['orig-name'] = node.scopeName() op['orig-name'] = node.scopeName()
op['type'] = node.kind() op['type'] = node.kind().lstrip('::onnx')
op['inputs'] = [] op['inputs'] = []
op['outputs'] = [] op['outputs'] = []
op['params'] = [] op['params'] = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment