diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index 5f570fb1d696bbdb025e89a5c0799a2fcb1902a0..073bd27b680a8bcb170673cbf61ae881053ec5bb 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -83,15 +83,9 @@ class SummaryGraph(object): def __init__(self, model, dummy_input): model = distiller.make_non_parallel_copy(model) with torch.onnx.set_training(model, False): - def _tupletree2device(obj, device): - if isinstance(obj, torch.Tensor): - return obj.to(device) - if not isinstance(obj, tuple): - raise TypeError("obj has to be a tree of tuples.") - return tuple(_tupletree2device(child, device) for child in obj) - + device = next(model.parameters()).device - dummy_input = _tupletree2device(dummy_input, device) + dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device) trace, _ = jit.get_trace_graph(model, dummy_input) # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes diff --git a/distiller/utils.py b/distiller/utils.py index 220ba3ce8a5befa8fbf4565b461cfafeca2f3a39..875b4043e1c8a3cb86037edd9965500db3a74605 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -589,3 +589,14 @@ def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max= if min_val >= max_val: raise ValueError('min_val must be less than max_val') return checker + + +def convert_tensors_recursively_to(val, *args, **kwargs): + """Applies `.to(*args, **kwargs)` to each tensor inside the val tree. Other values remain the same.""" + if isinstance(val, torch.Tensor): + return val.to(*args, **kwargs) + + if isinstance(val, (tuple, list)): + return type(val)(convert_tensors_recursively_to(item, *args, **kwargs) for item in val) + + return val