From 119f7601436208a920415d481ff9dd0f1b2d2ae9 Mon Sep 17 00:00:00 2001 From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com> Date: Thu, 28 Mar 2019 16:08:05 +0200 Subject: [PATCH] Added distiller.utils.convert_recursively_to (#209) * Added distiller.utils.convert_recursively_to , replaced _treetuple2device in SummaryGraph with it. * Renamed to convert_tensors_recursively_to --- distiller/summary_graph.py | 10 ++-------- distiller/utils.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index 5f570fb..073bd27 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -83,15 +83,9 @@ class SummaryGraph(object): def __init__(self, model, dummy_input): model = distiller.make_non_parallel_copy(model) with torch.onnx.set_training(model, False): - def _tupletree2device(obj, device): - if isinstance(obj, torch.Tensor): - return obj.to(device) - if not isinstance(obj, tuple): - raise TypeError("obj has to be a tree of tuples.") - return tuple(_tupletree2device(child, device) for child in obj) - + device = next(model.parameters()).device - dummy_input = _tupletree2device(dummy_input, device) + dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device) trace, _ = jit.get_trace_graph(model, dummy_input) # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes diff --git a/distiller/utils.py b/distiller/utils.py index 220ba3c..875b404 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -589,3 +589,14 @@ def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max= if min_val >= max_val: raise ValueError('min_val must be less than max_val') return checker + + +def convert_tensors_recursively_to(val, *args, **kwargs): + """ Applies `.to(*args, **kwargs)` to each tensor inside val tree. Other values remain the same.""" + if isinstance(val, torch.Tensor): + return val.to(*args, **kwargs) + + if isinstance(val, (tuple, list)): + return type(val)(convert_tensors_recursively_to(item, *args, **kwargs) for item in val) + + return val -- GitLab