diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 4108d541b1070356aa818284453ef461a9c316d2..5f570fb1d696bbdb025e89a5c0799a2fcb1902a0 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -83,7 +83,16 @@ class SummaryGraph(object):
     def __init__(self, model, dummy_input):
         model = distiller.make_non_parallel_copy(model)
         with torch.onnx.set_training(model, False):
-            trace, _ = jit.get_trace_graph(model, dummy_input.cuda())
+            def _tupletree2device(obj, device):
+                if isinstance(obj, torch.Tensor):
+                    return obj.to(device)
+                if not isinstance(obj, tuple):
+                    raise TypeError("obj has to be a tree of tuples.")
+                return tuple(_tupletree2device(child, device) for child in obj)
+
+            device = next(model.parameters()).device
+            dummy_input = _tupletree2device(dummy_input, device)
+            trace, _ = jit.get_trace_graph(model, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.