From bd57f8ad934abfe61214d117877e62d3d0519d78 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 28 Apr 2020 02:14:30 +0300
Subject: [PATCH] Improve error message when using
 ActivationAPoZRankedFilterPruner

See issue #444
---
 distiller/pruning/ranked_structures_pruner.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index ea11aa1..421220a 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -342,14 +342,18 @@ class ActivationRankedFilterPruner(_RankedStructureParameterPruner):
 
         # Use the parameter name to locate the module that has the activation sparsity statistics
         fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
-        #distiller.assign_layer_fq_names(model)
+        distiller.assign_layer_fq_names(model)
         module = distiller.find_module_by_fq_name(model, fq_name)
-        if module is None:
-            raise ValueError("Could not find a layer named %s in the model."
-                             "\nMake sure to use assign_layer_fq_names()" % fq_name)
+        assert module is not None
+
         if not hasattr(module, self.activation_rank_criterion):
-            raise ValueError("Could not find attribute \"%s\" in module %s"
-                             "\nMake sure to use SummaryActivationStatsCollector(\"%s\")" %
+            raise ValueError("Could not find attribute \"%s\" in module %s\n"
+                             "\tThis is pruner uses activation statistics collected during forward-"
+                             "passes of the network.\n"
+                             "\tThis error is an indication that these statistics "
+                             "have not been collected yet.\n"
+                             "\tMake sure to use SummaryActivationStatsCollector(\"%s\")\n"
+                             "\tFor more info see issue #444 (https://github.com/NervanaSystems/distiller/issues/444)"%
                              (self.activation_rank_criterion, fq_name, self.activation_rank_criterion))
 
         quality_criterion, std = getattr(module, self.activation_rank_criterion).value()
-- 
GitLab