diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index 3bb94aea9850a74cd0cdd03291a2816ea0ceace0..12f3e03562e7a35357cc456d013a3d952d41e0be 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -94,6 +94,7 @@ class SummaryGraph(object): Edge = collections.namedtuple('Edge', 'src dst') def __init__(self, model, dummy_input): + model = distiller.make_non_parallel_copy(model) with torch.onnx.set_training(model, False): trace, _ = jit.get_trace_graph(model, dummy_input.cuda()) @@ -142,6 +143,7 @@ class SummaryGraph(object): self.add_macs_attr() self.add_footprint_attr() self.add_arithmetic_intensity_attr() + del model def __create_op(self, onnx_node): op = {}