diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index b1c04a7a8f67d1dec4bb0460986d2d20c0ad8ca8..a650e54772164357816a15f33f13a842a8286874 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -60,7 +60,7 @@ class SummaryGraph(object): """ def __init__(self, model, dummy_input): with torch.onnx.set_training(model, False): - trace, _ = jit.trace(model, dummy_input) + trace, _ = jit.get_trace_graph(model, dummy_input) # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes # composing a GEMM operation; etc. @@ -79,7 +79,7 @@ class SummaryGraph(object): op = {} op['name'] = node.scopeName() op['orig-name'] = node.scopeName() - op['type'] = node.kind() + op['type'] = node.kind().lstrip('::onnx') op['inputs'] = [] op['outputs'] = [] op['params'] = []