diff --git a/distiller/utils.py b/distiller/utils.py
index 4676474f441b6dc1801f5f0179240ff98795aedc..9e9375512635c2a7fad9d894b423b733b41589aa 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -108,9 +108,11 @@ def denormalize_module_name(parallel_model, normalized_name):
         return normalized_name   # Did not find a module with the name <normalized_name>
 
 
-def volume(tensor):
+def volume(tensor_desc):
     """return the volume of a pytorch tensor"""
-    return np.prod(tensor.shape)
+    if isinstance(tensor_desc, tuple):
+        return np.prod(tensor_desc)
+    return np.prod(tensor_desc.shape)
 
 
 def density(tensor):