From 8ab0c3e92af4bce05372f0bf27b2bf4ab04762cd Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 10 May 2018 11:07:44 +0300
Subject: [PATCH] pytorch 0.4: The type() of a Tensor has changed

Following https://pytorch.org/2018/04/22/0_4_0-migration-guide.html, we
need to be more precise in how we use .type()
---
 distiller/pruning/ranked_structures_pruner.py | 2 +-
 distiller/thresholding.py                     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 0976475..683f118 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -48,7 +48,7 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
 
         bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)
         threshold = bottomk[-1]
-        binary_map = filter_mags.gt(threshold).type(type(param.data))
+        binary_map = filter_mags.gt(threshold).type(param.data.type())
         expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
         zeros_mask_dict[param_name].mask = expanded.view(param.size(0), param.size(1), param.size(2), param.size(3))
         msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index 9e9e834..f18899b 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -106,7 +106,7 @@ class GroupThresholdMixin(object):
             kernel_means = view_2d.abs().mean(dim=1)
             k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
             thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda()
-            binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(type(param.data))
+            binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())
 
             # Now let's expand back up to a 4D mask
             a = binary_map.expand(num_filters, num_kernels_per_filter)
-- 
GitLab