From 735fdd0b5596c3c726ae7ad115edf894d4bf8730 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 1 Nov 2018 15:19:39 +0200 Subject: [PATCH] Bug fix: fix crash when generating PNG of network --- distiller/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/distiller/utils.py b/distiller/utils.py index 4676474..9e93755 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -108,9 +108,11 @@ def denormalize_module_name(parallel_model, normalized_name): return normalized_name # Did not find a module with the name <normalized_name> -def volume(tensor): +def volume(tensor_desc): """return the volume of a pytorch tensor""" - return np.prod(tensor.shape) + if isinstance(tensor_desc, tuple): + return np.prod(tensor_desc) + return np.prod(tensor_desc.shape) def density(tensor): -- GitLab