From cee295fb817d8e022304cb6f20e32e5642e58b4a Mon Sep 17 00:00:00 2001 From: Lev Zlotnik <lev.zlotnik@intel.com> Date: Mon, 4 Mar 2019 16:02:22 +0200 Subject: [PATCH] Fix issue #155 --- distiller/summary_graph.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index 4108d54..5f570fb 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -83,7 +83,16 @@ class SummaryGraph(object): def __init__(self, model, dummy_input): model = distiller.make_non_parallel_copy(model) with torch.onnx.set_training(model, False): - trace, _ = jit.get_trace_graph(model, dummy_input.cuda()) + def _tupletree2device(obj, device): + if isinstance(obj, torch.Tensor): + return obj.to(device) + if not isinstance(obj, tuple): + raise TypeError("obj has to be a tree of tuples.") + return tuple(_tupletree2device(child, device) for child in obj) + + device = next(model.parameters()).device + dummy_input = _tupletree2device(dummy_input, device) + trace, _ = jit.get_trace_graph(model, dummy_input) # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes # composing a GEMM operation; etc. -- GitLab