Skip to content
Snippets Groups Projects
Commit 8ab0c3e9 authored by Neta Zmora's avatar Neta Zmora
Browse files

pytorch 0.4: The type() of a Tensor has changed

Following https://pytorch.org/2018/04/22/0_4_0-migration-guide.html, we
need to be more precise in how we use .type()
parent deb7dd15
No related branches found
No related tags found
No related merge requests found
......@@ -48,7 +48,7 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)
threshold = bottomk[-1]
binary_map = filter_mags.gt(threshold).type(type(param.data))
binary_map = filter_mags.gt(threshold).type(param.data.type())
expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
zeros_mask_dict[param_name].mask = expanded.view(param.size(0), param.size(1), param.size(2), param.size(3))
msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
......
......@@ -106,7 +106,7 @@ class GroupThresholdMixin(object):
kernel_means = view_2d.abs().mean(dim=1)
k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda()
binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(type(param.data))
binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())
# Now let's expand back up to a 4D mask
a = binary_map.expand(num_filters, num_kernels_per_filter)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment