diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 4108d541b1070356aa818284453ef461a9c316d2..5f570fb1d696bbdb025e89a5c0799a2fcb1902a0 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -83,7 +83,24 @@ class SummaryGraph(object):
     def __init__(self, model, dummy_input):
         model = distiller.make_non_parallel_copy(model)
         with torch.onnx.set_training(model, False):
-            trace, _ = jit.get_trace_graph(model, dummy_input.cuda())
+            def _tupletree2device(obj, device):
+                """Return `obj` with every tensor moved to `device`.
+
+                `obj` is a tensor or a nested tuple of tensors; anything else raises TypeError.
+                """
+                if isinstance(obj, torch.Tensor):
+                    return obj.to(device)
+                if not isinstance(obj, tuple):
+                    raise TypeError("obj has to be a tree of tuples.")
+                return tuple(_tupletree2device(child, device) for child in obj)
+
+            # Put the dummy input on the device of the model's first parameter rather than
+            # calling .cuda() unconditionally.  A model without parameters gives no device
+            # to follow, so its dummy input is left wherever the caller placed it.
+            first_param = next(model.parameters(), None)
+            if first_param is not None:
+                dummy_input = _tupletree2device(dummy_input, first_param.device)
+            trace, _ = jit.get_trace_graph(model, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.