diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 5f570fb1d696bbdb025e89a5c0799a2fcb1902a0..073bd27b680a8bcb170673cbf61ae881053ec5bb 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -83,15 +83,9 @@ class SummaryGraph(object):
     def __init__(self, model, dummy_input):
         model = distiller.make_non_parallel_copy(model)
         with torch.onnx.set_training(model, False):
-            def _tupletree2device(obj, device):
-                if isinstance(obj, torch.Tensor):
-                    return obj.to(device)
-                if not isinstance(obj, tuple):
-                    raise TypeError("obj has to be a tree of tuples.")
-                return tuple(_tupletree2device(child, device) for child in obj)
-
+
             device = next(model.parameters()).device
-            dummy_input = _tupletree2device(dummy_input, device)
+            dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
             trace, _ = jit.get_trace_graph(model, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
diff --git a/distiller/utils.py b/distiller/utils.py
index 220ba3ce8a5befa8fbf4565b461cfafeca2f3a39..875b4043e1c8a3cb86037edd9965500db3a74605 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -589,3 +589,14 @@ def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max=
     if min_val >= max_val:
         raise ValueError('min_val must be less than max_val')
     return checker
+
+
+def convert_tensors_recursively_to(val, *args, **kwargs):
+    """Applies `.to(*args, **kwargs)` to each tensor inside the val tree; other values remain the same."""
+    if isinstance(val, torch.Tensor):
+        return val.to(*args, **kwargs)
+
+    if isinstance(val, (tuple, list)):
+        return type(val)(convert_tensors_recursively_to(item, *args, **kwargs) for item in val)
+
+    return val