From 7cecd48d150e58e1e5376de9fee1c05a7b9d9672 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 5 Dec 2018 00:12:47 +0200
Subject: [PATCH] AMC: adjust to latest channel/filter pruning APIs

---
 examples/automated_deep_compression/ADC.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/examples/automated_deep_compression/ADC.py b/examples/automated_deep_compression/ADC.py
index 8003341..f4ebe20 100755
--- a/examples/automated_deep_compression/ADC.py
+++ b/examples/automated_deep_compression/ADC.py
@@ -66,6 +66,8 @@ PERFORM_THINNING = True
 #reward = -1 * math.log(total_macs)
 #reward =  -1 * vloss
 
+steps_per_episode = 13  # TODO: this should not be hard-coded
+
 
 def do_adc(model, dataset, arch, data_loader, validate_fn, save_checkpoint_fn):
     np.random.seed()
@@ -88,6 +90,7 @@ def coach_adc(model, dataset, arch, data_loader, validate_fn, save_checkpoint_fn
         exploration_noise = 0.5
         #exploration_noise = 0.25
         exploitation_decay = 0.996
+        # These parameters are passed to the Distiller environment
         graph_manager.env_params.additional_simulator_parameters = {
             'model': model,
             'dataset': dataset,
@@ -125,8 +128,6 @@ def coach_adc(model, dataset, arch, data_loader, validate_fn, save_checkpoint_fn
             #'reward_fn': lambda top1, total_macs: min(top1/100, 0.75)
         }
 
-    #msglogger.debug('Experiment configuarion:\n' + json.dumps(graph_manager.env_params.additional_simulator_parameters, indent=2))
-    steps_per_episode = 13
     agent_params.exploration.noise_percentage_schedule = PieceWiseSchedule([(ConstantSchedule(exploration_noise),
                                                                              EnvironmentSteps(100*steps_per_episode)),
                                                                             (ExponentialSchedule(exploration_noise, 0, exploitation_decay),
@@ -172,7 +173,7 @@ class CNNEnvironment(gym.Env):
         self.desired_reduction = desired_reduction
         self.STATE_EMBEDDING_LEN = len(Observation._fields)
         if self.onehot_encoding:
-            self.STATE_EMBEDDING_LEN += 12
+            self.STATE_EMBEDDING_LEN += (steps_per_episode - 1)
         self.observation_space = spaces.Box(0, float("inf"), shape=(self.STATE_EMBEDDING_LEN,))
 
     def reset(self, init_only=False):
@@ -202,12 +203,12 @@ class CNNEnvironment(gym.Env):
             return self.conv_layers[idx]
         except KeyError:
             return None
+
     def episode_is_done(self):
         return self.current_layer_id == self.num_layers()
 
     def remaining_macs(self):
-        """Return the amount of MACs remaining in the model's unprocessed
-        Convolution layers.
+        """Return the amount of MACs remaining in the model's unprocessed Convolution layers.
         This is normalized to the range 0..1
         """
         #return 1 - self.sum_list_macs(self.unprocessed_layers) / self.dense_model_macs
@@ -228,7 +229,7 @@ class CNNEnvironment(gym.Env):
     #     return total_macs
 
     def render(self, mode, close):
-        """Provide some feedback to the user about what's going on
+        """Provide some feedback to the user about what's going on.
         This is invoked by the Agent.
         """
         if self.current_layer_id == 0:
@@ -281,7 +282,7 @@ class CNNEnvironment(gym.Env):
     def step(self, action):
         """Take a step, given an action.
 
-        The action represents the desired sparsity.
+        The action represents the desired density.
         This function is invoked by the Agent.
         """
         msglogger.info("env.step - current_layer_id={} action={}".format(self.current_layer_id, action))
@@ -423,16 +424,17 @@ class CNNEnvironment(gym.Env):
 
         if prune_what == "channels":
             calculate_sparsity = distiller.sparsity_ch
-            reg_regims = {conv_pname: [fraction_to_prune, "Channels"]}
             remove_structures = distiller.remove_channels
+            group_type = "Channels"
         elif prune_what == "filters":
             calculate_sparsity = distiller.sparsity_3D
-            reg_regims = {conv_pname: [fraction_to_prune, "3D"]}
+            group_type = "Filters"
             remove_structures = distiller.remove_filters
         else:
             raise ValueError("unsupported structure {}".format(prune_what))
         # Create a channel-ranking pruner
-        pruner = distiller.pruning.L1RankedStructureParameterPruner("adc_pruner", reg_regims)
+        pruner = distiller.pruning.L1RankedStructureParameterPruner("adc_pruner", group_type,
+                                                                    fraction_to_prune, conv_pname)
         pruner.set_param_mask(conv_p, conv_pname, self.zeros_mask_dict, meta=None)
 
         if (self.zeros_mask_dict[conv_pname].mask is None or
-- 
GitLab