diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index ea11aa106cc140ecd101aa1985df3e940881a101..421220ad3fa4c50696c30eb336f49d5569d0254c 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -342,14 +342,18 @@ class ActivationRankedFilterPruner(_RankedStructureParameterPruner): # Use the parameter name to locate the module that has the activation sparsity statistics fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")] - #distiller.assign_layer_fq_names(model) + distiller.assign_layer_fq_names(model) module = distiller.find_module_by_fq_name(model, fq_name) - if module is None: - raise ValueError("Could not find a layer named %s in the model." - "\nMake sure to use assign_layer_fq_names()" % fq_name) + assert module is not None + if not hasattr(module, self.activation_rank_criterion): - raise ValueError("Could not find attribute \"%s\" in module %s" - "\nMake sure to use SummaryActivationStatsCollector(\"%s\")" % + raise ValueError("Could not find attribute \"%s\" in module %s\n" + "\tThis is pruner uses activation statistics collected during forward-" + "passes of the network.\n" + "\tThis error is an indication that these statistics " + "have not been collected yet.\n" + "\tMake sure to use SummaryActivationStatsCollector(\"%s\")\n" + "\tFor more info see issue #444 (https://github.com/NervanaSystems/distiller/issues/444)"% (self.activation_rank_criterion, fq_name, self.activation_rank_criterion)) quality_criterion, std = getattr(module, self.activation_rank_criterion).value()