From 8ab0c3e92af4bce05372f0bf27b2bf4ab04762cd Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 10 May 2018 11:07:44 +0300 Subject: [PATCH] pytorch 0.4: The type() of a Tensor has changed Following https://pytorch.org/2018/04/22/0_4_0-migration-guide.html, we need to be more precise in how we use .type() --- distiller/pruning/ranked_structures_pruner.py | 2 +- distiller/thresholding.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 0976475..683f118 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -48,7 +48,7 @@ class L1RankedStructureParameterPruner(_ParameterPruner): bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True) threshold = bottomk[-1] - binary_map = filter_mags.gt(threshold).type(type(param.data)) + binary_map = filter_mags.gt(threshold).type(param.data.type()) expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous() zeros_mask_dict[param_name].mask = expanded.view(param.size(0), param.size(1), param.size(2), param.size(3)) msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, diff --git a/distiller/thresholding.py b/distiller/thresholding.py index 9e9e834..f18899b 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -106,7 +106,7 @@ class GroupThresholdMixin(object): kernel_means = view_2d.abs().mean(dim=1) k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t() thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda() - binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(type(param.data)) + binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type()) # Now let's expand back up to a 4D mask a = binary_map.expand(num_filters, num_kernels_per_filter) -- GitLab