From 3ca2985d20b3de42cc346b186780a84509789834 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 22 Jan 2019 14:22:45 +0200
Subject: [PATCH] SummaryGraph: always convert base the graph on a non-parallel
 model

The use of DataParallel is causing various small problems when used
in conjunction with SummaryGraph.
The best solution is to force SummaryGraph to use a non-data-parallel
version of the model and to always normalize node names when accessing
SummaryGraph operations.
---
 apputils/model_summaries.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py
index 3bb94ae..12f3e03 100755
--- a/apputils/model_summaries.py
+++ b/apputils/model_summaries.py
@@ -94,6 +94,7 @@ class SummaryGraph(object):
     Edge = collections.namedtuple('Edge', 'src dst')
 
     def __init__(self, model, dummy_input):
+        model = distiller.make_non_parallel_copy(model)
         with torch.onnx.set_training(model, False):
             trace, _ = jit.get_trace_graph(model, dummy_input.cuda())
 
@@ -142,6 +143,7 @@ class SummaryGraph(object):
         self.add_macs_attr()
         self.add_footprint_attr()
         self.add_arithmetic_intensity_attr()
+        del model
 
     def __create_op(self, onnx_node):
         op = {}
-- 
GitLab