Skip to content
Snippets Groups Projects
Commit cee295fb authored by Lev Zlotnik's avatar Lev Zlotnik
Browse files

Fix issue #155

parent 7fb41d6f
No related branches found
No related tags found
No related merge requests found
...@@ -83,7 +83,16 @@ class SummaryGraph(object): ...@@ -83,7 +83,16 @@ class SummaryGraph(object):
def __init__(self, model, dummy_input): def __init__(self, model, dummy_input):
model = distiller.make_non_parallel_copy(model) model = distiller.make_non_parallel_copy(model)
with torch.onnx.set_training(model, False): with torch.onnx.set_training(model, False):
trace, _ = jit.get_trace_graph(model, dummy_input.cuda()) def _tupletree2device(obj, device):
if isinstance(obj, torch.Tensor):
return obj.to(device)
if not isinstance(obj, tuple):
raise TypeError("obj has to be a tree of tuples.")
return tuple(_tupletree2device(child, device) for child in obj)
device = next(model.parameters()).device
dummy_input = _tupletree2device(dummy_input, device)
trace, _ = jit.get_trace_graph(model, dummy_input)
# Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
# composing a GEMM operation; etc. # composing a GEMM operation; etc.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment