From cee295fb817d8e022304cb6f20e32e5642e58b4a Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <lev.zlotnik@intel.com>
Date: Mon, 4 Mar 2019 16:02:22 +0200
Subject: [PATCH] Fix issue #155

---
 distiller/summary_graph.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 4108d54..5f570fb 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -83,7 +83,16 @@ class SummaryGraph(object):
     def __init__(self, model, dummy_input):
         model = distiller.make_non_parallel_copy(model)
         with torch.onnx.set_training(model, False):
-            trace, _ = jit.get_trace_graph(model, dummy_input.cuda())
+            def _tupletree2device(obj, device):
+                if isinstance(obj, torch.Tensor):
+                    return obj.to(device)
+                if not isinstance(obj, tuple):
+                    raise TypeError("obj has to be a tree of tuples.")
+                return tuple(_tupletree2device(child, device) for child in obj)
+
+            device = next(model.parameters()).device
+            dummy_input = _tupletree2device(dummy_input, device)
+            trace, _ = jit.get_trace_graph(model, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
-- 
GitLab