From 7cecd48d150e58e1e5376de9fee1c05a7b9d9672 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 5 Dec 2018 00:12:47 +0200 Subject: [PATCH] AMC: adjust to latest channel/filter pruning APIs --- examples/automated_deep_compression/ADC.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/automated_deep_compression/ADC.py b/examples/automated_deep_compression/ADC.py index 8003341..f4ebe20 100755 --- a/examples/automated_deep_compression/ADC.py +++ b/examples/automated_deep_compression/ADC.py @@ -66,6 +66,8 @@ PERFORM_THINNING = True #reward = -1 * math.log(total_macs) #reward = -1 * vloss +steps_per_episode = 13 # TODO: this should not be hard-coded + def do_adc(model, dataset, arch, data_loader, validate_fn, save_checkpoint_fn): np.random.seed() @@ -88,6 +90,7 @@ def coach_adc(model, dataset, arch, data_loader, validate_fn, save_checkpoint_fn exploration_noise = 0.5 #exploration_noise = 0.25 exploitation_decay = 0.996 + # These parameters are passed to the Distiller environment graph_manager.env_params.additional_simulator_parameters = { 'model': model, 'dataset': dataset, @@ -125,8 +128,6 @@ def coach_adc(model, dataset, arch, data_loader, validate_fn, save_checkpoint_fn #'reward_fn': lambda top1, total_macs: min(top1/100, 0.75) } - #msglogger.debug('Experiment configuarion:\n' + json.dumps(graph_manager.env_params.additional_simulator_parameters, indent=2)) - steps_per_episode = 13 agent_params.exploration.noise_percentage_schedule = PieceWiseSchedule([(ConstantSchedule(exploration_noise), EnvironmentSteps(100*steps_per_episode)), (ExponentialSchedule(exploration_noise, 0, exploitation_decay), @@ -172,7 +173,7 @@ class CNNEnvironment(gym.Env): self.desired_reduction = desired_reduction self.STATE_EMBEDDING_LEN = len(Observation._fields) if self.onehot_encoding: - self.STATE_EMBEDDING_LEN += 12 + self.STATE_EMBEDDING_LEN += (steps_per_episode - 1) self.observation_space = spaces.Box(0, float("inf"), shape=(self.STATE_EMBEDDING_LEN,)) def reset(self, init_only=False): @@ -202,12 +203,12 @@ class CNNEnvironment(gym.Env): return self.conv_layers[idx] except KeyError: return None + def episode_is_done(self): return self.current_layer_id == self.num_layers() def remaining_macs(self): - """Return the amount of MACs remaining in the model's unprocessed - Convolution layers. + """Return the amount of MACs remaining in the model's unprocessed Convolution layers. This is normalized to the range 0..1 """ #return 1 - self.sum_list_macs(self.unprocessed_layers) / self.dense_model_macs @@ -228,7 +229,7 @@ class CNNEnvironment(gym.Env): # return total_macs def render(self, mode, close): - """Provide some feedback to the user about what's going on + """Provide some feedback to the user about what's going on. This is invoked by the Agent. """ if self.current_layer_id == 0: @@ -281,7 +282,7 @@ class CNNEnvironment(gym.Env): def step(self, action): """Take a step, given an action. - The action represents the desired sparsity. + The action represents the desired density. This function is invoked by the Agent. """ msglogger.info("env.step - current_layer_id={} action={}".format(self.current_layer_id, action)) @@ -423,16 +424,17 @@ class CNNEnvironment(gym.Env): if prune_what == "channels": calculate_sparsity = distiller.sparsity_ch - reg_regims = {conv_pname: [fraction_to_prune, "Channels"]} remove_structures = distiller.remove_channels + group_type = "Channels" elif prune_what == "filters": calculate_sparsity = distiller.sparsity_3D - reg_regims = {conv_pname: [fraction_to_prune, "3D"]} + group_type = "Filters" remove_structures = distiller.remove_filters else: raise ValueError("unsupported structure {}".format(prune_what)) # Create a channel-ranking pruner - pruner = distiller.pruning.L1RankedStructureParameterPruner("adc_pruner", reg_regims) + pruner = distiller.pruning.L1RankedStructureParameterPruner("adc_pruner", group_type, + fraction_to_prune, conv_pname) pruner.set_param_mask(conv_p, conv_pname, self.zeros_mask_dict, meta=None) if (self.zeros_mask_dict[conv_pname].mask is None or -- GitLab