From 3fe741cc2d9bd9470cead8885e0a1af6d6812a7f Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Wed, 16 May 2018 14:02:40 +0300
Subject: [PATCH] Fix additional 0-dim accesses

---
 distiller/pruning/sensitivity_pruner.py | 2 +-
 distiller/quantization/q_utils.py       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/distiller/pruning/sensitivity_pruner.py b/distiller/pruning/sensitivity_pruner.py
index 34a951a..624b84d 100755
--- a/distiller/pruning/sensitivity_pruner.py
+++ b/distiller/pruning/sensitivity_pruner.py
@@ -45,7 +45,7 @@ class SensitivityPruner(_ParameterPruner):
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         if not hasattr(param, 'stddev'):
-            param.stddev = torch.std(param).data[0]
+            param.stddev = torch.std(param).item()
 
         if param_name not in self.sensitivities:
             if '*' not in self.sensitivities:
diff --git a/distiller/quantization/q_utils.py b/distiller/quantization/q_utils.py
index 6a32545..efad568 100644
--- a/distiller/quantization/q_utils.py
+++ b/distiller/quantization/q_utils.py
@@ -55,7 +55,7 @@ def linear_dequantize(input, scale_factor, inplace=False):
 
 
 def get_tensor_max_abs(tensor):
-    return max(abs(tensor.max().data[0]), abs(tensor.min().data[0]))
+    return max(abs(tensor.max().item()), abs(tensor.min().item()))
 
 
 def get_quantized_range(num_bits, signed=True):
-- 
GitLab