diff --git a/distiller/utils.py b/distiller/utils.py index 4676474f441b6dc1801f5f0179240ff98795aedc..9e9375512635c2a7fad9d894b423b733b41589aa 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -108,9 +108,11 @@ def denormalize_module_name(parallel_model, normalized_name): return normalized_name # Did not find a module with the name <normalized_name> -def volume(tensor): +def volume(tensor_desc): """return the volume of a pytorch tensor""" - return np.prod(tensor.shape) + if isinstance(tensor_desc, tuple): + return np.prod(tensor_desc) + return np.prod(tensor_desc.shape) def density(tensor):