diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 097647565a642389768a6cb4487c84100b039667..683f11822fbefb26351e90e9c832731dfcd973c1 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -48,7 +48,7 @@ class L1RankedStructureParameterPruner(_ParameterPruner): bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True) threshold = bottomk[-1] - binary_map = filter_mags.gt(threshold).type(type(param.data)) + binary_map = filter_mags.gt(threshold).type(param.data.type()) expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous() zeros_mask_dict[param_name].mask = expanded.view(param.size(0), param.size(1), param.size(2), param.size(3)) msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, diff --git a/distiller/thresholding.py b/distiller/thresholding.py index 9e9e834a5209fa37737b9d7d8a5812d0231aaf61..f18899ba5a36039c78b46cf4bf5baabbcfdaaa2c 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -106,7 +106,7 @@ class GroupThresholdMixin(object): kernel_means = view_2d.abs().mean(dim=1) k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t() thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda() - binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(type(param.data)) + binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type()) # Now let's expand back up to a 4D mask a = binary_map.expand(num_filters, num_kernels_per_filter)