Skip to content
Snippets Groups Projects
Unverified Commit 119f7601 authored by Lev Zlotnik's avatar Lev Zlotnik Committed by GitHub
Browse files

Added distiller.utils.convert_recursively_to (#209)

* Added distiller.utils.convert_recursively_to , replaced _treetuple2device in SummaryGraph with it.

* Renamed to convert_tensors_recursively_to
parent 958a52f6
No related branches found
No related tags found
No related merge requests found
...@@ -83,15 +83,9 @@ class SummaryGraph(object): ...@@ -83,15 +83,9 @@ class SummaryGraph(object):
def __init__(self, model, dummy_input): def __init__(self, model, dummy_input):
model = distiller.make_non_parallel_copy(model) model = distiller.make_non_parallel_copy(model)
with torch.onnx.set_training(model, False): with torch.onnx.set_training(model, False):
def _tupletree2device(obj, device):
if isinstance(obj, torch.Tensor):
return obj.to(device)
if not isinstance(obj, tuple):
raise TypeError("obj has to be a tree of tuples.")
return tuple(_tupletree2device(child, device) for child in obj)
device = next(model.parameters()).device device = next(model.parameters()).device
dummy_input = _tupletree2device(dummy_input, device) dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
trace, _ = jit.get_trace_graph(model, dummy_input) trace, _ = jit.get_trace_graph(model, dummy_input)
# Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
......
...@@ -589,3 +589,14 @@ def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max= ...@@ -589,3 +589,14 @@ def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max=
if min_val >= max_val: if min_val >= max_val:
raise ValueError('min_val must be less than max_val') raise ValueError('min_val must be less than max_val')
return checker return checker
def convert_tensors_recursively_to(val, *args, **kwargs):
""" Applies `.to(*args, **kwargs)` to each tensor inside val tree. Other values remain the same."""
if isinstance(val, torch.Tensor):
return val.to(*args, **kwargs)
if isinstance(val, (tuple, list)):
return type(val)(convert_tensors_recursively_to(item, *args, **kwargs) for item in val)
return val
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment