From 3ca2985d20b3de42cc346b186780a84509789834 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Tue, 22 Jan 2019 14:22:45 +0200 Subject: [PATCH] SummaryGraph: always convert base the graph on a non-parallel model The use of DataParallel is causing various small problems when used in conjunction with SummaryGraph. The best solution is to force SummaryGraph to use a non-data-parallel version of the model and to always normalize node names when accessing SummaryGraph operations. --- apputils/model_summaries.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index 3bb94ae..12f3e03 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -94,6 +94,7 @@ class SummaryGraph(object): Edge = collections.namedtuple('Edge', 'src dst') def __init__(self, model, dummy_input): + model = distiller.make_non_parallel_copy(model) with torch.onnx.set_training(model, False): trace, _ = jit.get_trace_graph(model, dummy_input.cuda()) @@ -142,6 +143,7 @@ class SummaryGraph(object): self.add_macs_attr() self.add_footprint_attr() self.add_arithmetic_intensity_attr() + del model def __create_op(self, onnx_node): op = {} -- GitLab