From 119f7601436208a920415d481ff9dd0f1b2d2ae9 Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Thu, 28 Mar 2019 16:08:05 +0200
Subject: [PATCH] Added distiller.utils.convert_recursively_to (#209)

* Added distiller.utils.convert_recursively_to , replaced _treetuple2device in SummaryGraph with it.

* Renamed to convert_tensors_recursively_to
---
 distiller/summary_graph.py | 10 ++--------
 distiller/utils.py         | 11 +++++++++++
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 5f570fb..073bd27 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -83,15 +83,9 @@ class SummaryGraph(object):
     def __init__(self, model, dummy_input):
         model = distiller.make_non_parallel_copy(model)
         with torch.onnx.set_training(model, False):
-            def _tupletree2device(obj, device):
-                if isinstance(obj, torch.Tensor):
-                    return obj.to(device)
-                if not isinstance(obj, tuple):
-                    raise TypeError("obj has to be a tree of tuples.")
-                return tuple(_tupletree2device(child, device) for child in obj)
-
+            
             device = next(model.parameters()).device
-            dummy_input = _tupletree2device(dummy_input, device)
+            dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
             trace, _ = jit.get_trace_graph(model, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
diff --git a/distiller/utils.py b/distiller/utils.py
index 220ba3c..875b404 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -589,3 +589,14 @@ def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max=
     if min_val >= max_val:
         raise ValueError('min_val must be less than max_val')
     return checker
+
+
+def convert_tensors_recursively_to(val, *args, **kwargs):
+    """ Applies `.to(*args, **kwargs)` to each tensor inside val tree. Other values remain the same."""
+    if isinstance(val, torch.Tensor):
+        return val.to(*args, **kwargs)
+
+    if isinstance(val, (tuple, list)):
+        return type(val)(convert_tensors_recursively_to(item, *args, **kwargs) for item in val)
+
+    return val
-- 
GitLab