From 735fdd0b5596c3c726ae7ad115edf894d4bf8730 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 1 Nov 2018 15:19:39 +0200
Subject: [PATCH] Bug fix: fix crash when generating PNG of network

---
 distiller/utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/distiller/utils.py b/distiller/utils.py
index 4676474..9e93755 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -108,9 +108,11 @@ def denormalize_module_name(parallel_model, normalized_name):
         return normalized_name   # Did not find a module with the name <normalized_name>
 
 
-def volume(tensor):
+def volume(tensor_desc):
     """return the volume of a pytorch tensor"""
-    return np.prod(tensor.shape)
+    if isinstance(tensor_desc, tuple):
+        return np.prod(tensor_desc)
+    return np.prod(tensor_desc.shape)
 
 
 def density(tensor):
-- 
GitLab