From 3fe741cc2d9bd9470cead8885e0a1af6d6812a7f Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Wed, 16 May 2018 14:02:40 +0300
Subject: [PATCH] Fix additional 0-dim accesses

---
 distiller/pruning/sensitivity_pruner.py | 2 +-
 distiller/quantization/q_utils.py       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/distiller/pruning/sensitivity_pruner.py b/distiller/pruning/sensitivity_pruner.py
index 34a951a..624b84d 100755
--- a/distiller/pruning/sensitivity_pruner.py
+++ b/distiller/pruning/sensitivity_pruner.py
@@ -45,7 +45,7 @@ class SensitivityPruner(_ParameterPruner):
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         if not hasattr(param, 'stddev'):
-            param.stddev = torch.std(param).data[0]
+            param.stddev = torch.std(param).item()
 
         if param_name not in self.sensitivities:
             if '*' not in self.sensitivities:
diff --git a/distiller/quantization/q_utils.py b/distiller/quantization/q_utils.py
index 6a32545..efad568 100644
--- a/distiller/quantization/q_utils.py
+++ b/distiller/quantization/q_utils.py
@@ -55,7 +55,7 @@ def linear_dequantize(input, scale_factor, inplace=False):
 
 
 def get_tensor_max_abs(tensor):
-    return max(abs(tensor.max().data[0]), abs(tensor.min().data[0]))
+    return max(abs(tensor.max().item()), abs(tensor.min().item()))
 
 
 def get_quantized_range(num_bits, signed=True):
-- 
GitLab
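
Background on the fix (an illustrative sketch, not part of the patch): starting
with PyTorch 0.4, reduction ops such as torch.std(), Tensor.max() and
Tensor.min() return 0-dim tensors rather than 1-element tensors, so the old
tensor.data[0] indexing stops working (early 0.4 builds only emitted a
deprecation warning; later releases raise IndexError). Calling .item() is the
supported way to extract the Python scalar. The snippet below reproduces both
patterns touched by this patch; the tensor t is just a stand-in input.

    import torch

    t = torch.randn(3, 3)

    # torch.std() now yields a 0-dim tensor, not a 1-element tensor.
    std = torch.std(t)
    assert std.dim() == 0

    # Old idiom from sensitivity_pruner.py -- breaks on 0-dim tensors:
    #   stddev = std.data[0]   # IndexError on recent PyTorch (warning in early 0.4)

    # New idiom -- extracts a Python float from the 0-dim tensor:
    stddev = std.item()

    # Same change in q_utils.get_tensor_max_abs(): the max()/min()
    # reductions are also 0-dim now, so .item() replaces .data[0] there too.
    max_abs = max(abs(t.max().item()), abs(t.min().item()))

    print(stddev, max_abs)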