diff --git a/examples/automated_deep_compression/ADC.py b/examples/automated_deep_compression/ADC.py
index 215f6c25fb993f8151d5e9ba1c2b5bd8518e919b..80bca3a42807daecaafeb74e3d28937943822e61 100755
--- a/examples/automated_deep_compression/ADC.py
+++ b/examples/automated_deep_compression/ADC.py
@@ -226,14 +226,14 @@ def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn
     # Create a dictionary of parameters that Coach will handover to DistillerWrapperEnvironment
     # Once it creates it.
     services = distiller.utils.MutableNamedTuple({
-            'validate_fn': validate_fn,
-            'save_checkpoint_fn': save_checkpoint_fn,
-            'train_fn': train_fn})
+        'validate_fn': validate_fn,
+        'save_checkpoint_fn': save_checkpoint_fn,
+        'train_fn': train_fn})
 
     app_args = distiller.utils.MutableNamedTuple({
-            'dataset': dataset,
-            'arch': arch,
-            'optimizer_data': optimizer_data})
+        'dataset': dataset,
+        'arch': arch,
+        'optimizer_data': optimizer_data})
 
     amc_cfg = distiller.utils.MutableNamedTuple({
         'protocol': args.amc_protocol,
@@ -243,7 +243,8 @@ def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn
         'action_range': action_range,
         'conv_cnt': conv_cnt,
         'reward_frequency': args.amc_reward_frequency,
-        'ft_frequency': args.amc_ft_frequency})
+        'ft_frequency': args.amc_ft_frequency,
+        'pruning_pattern': "filters"})  # "channels"})
     #
     #net_wrapper = NetworkWrapper(model, app_args, services)
     #return sample_networks(net_wrapper, services)
@@ -344,9 +345,21 @@ resnet56_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight
                    "module.layer3.3.conv1.weight", "module.layer3.4.conv1.weight", "module.layer3.5.conv1.weight",
                    "module.layer3.6.conv1.weight", "module.layer3.7.conv1.weight", "module.layer3.8.conv1.weight"]
 
+plain20_params = ["module.layer1.0.conv1.weight", "module.layer1.0.conv2.weight",
+                  "module.layer1.1.conv1.weight", "module.layer1.1.conv2.weight",
+                  "module.layer1.2.conv1.weight", "module.layer1.2.conv2.weight",
+                  "module.layer2.0.conv1.weight", "module.layer2.0.conv2.weight",
+                  "module.layer2.1.conv1.weight", "module.layer2.1.conv2.weight",
+                  "module.layer2.2.conv1.weight", "module.layer2.2.conv2.weight",
+                  "module.layer3.0.conv1.weight", "module.layer3.0.conv2.weight",
+                  "module.layer3.1.conv1.weight", "module.layer3.1.conv2.weight",
+                  "module.layer3.2.conv1.weight", "module.layer3.2.conv2.weight"]
+
+
 resnet50_layers = [param[:-len(".weight")] for param in resnet50_params]
 resnet20_layers = [param[:-len(".weight")] for param in resnet20_params]
 resnet56_layers = [param[:-len(".weight")] for param in resnet56_params]
+plain20_layers = [param[:-len(".weight")] for param in plain20_params]
 
 
 class NetworkWrapper(object):
@@ -375,6 +388,8 @@ class NetworkWrapper(object):
             resnet_layers = resnet56_layers
         elif self.app_args.arch == "resnet50":
             resnet_layers = resnet50_layers
+        elif self.app_args.arch == "plain20_cifar":
+            resnet_layers = plain20_layers
         return collect_conv_details(model, self.app_args.dataset, True, resnet_layers)
 
     def num_layers(self):
@@ -426,16 +441,16 @@ class NetworkWrapper(object):
         conv_pname = layer.name + ".weight"
         conv_p = distiller.model_find_param(self.model, conv_pname)
 
-        msglogger.info("ADC: removing %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname))
+        msglogger.info("ADC: trying to remove %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname))
 
         if prune_what == "channels":
             calculate_sparsity = distiller.sparsity_ch
-            remove_structures = distiller.remove_channels
+            remove_structures_fn = distiller.remove_channels
             group_type = "Channels"
         elif prune_what == "filters":
             calculate_sparsity = distiller.sparsity_3D
             group_type = "Filters"
-            remove_structures = distiller.remove_filters
+            remove_structures_fn = distiller.remove_filters
         else:
             raise ValueError("unsupported structure {}".format(prune_what))
         # Create a channel-ranking pruner
@@ -446,12 +461,12 @@ class NetworkWrapper(object):
 
         if (self.zeros_mask_dict[conv_pname].mask is None or
             calculate_sparsity(self.zeros_mask_dict[conv_pname].mask) == 0):
-            msglogger.info("remove_structures: aborting because there are no channels to prune")
+            msglogger.info("remove_structures: aborting because there are no structures to prune")
             return 0
 
         # Use the mask to prune
         self.zeros_mask_dict[conv_pname].apply_mask(conv_p)
-        remove_structures(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None)
+        remove_structures_fn(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None)
         conv_p = distiller.model_find_param(self.model, conv_pname)
         return 1 - (self.get_layer_macs(layer) / macs_before)
 
@@ -476,8 +491,6 @@ class NetworkWrapper(object):
 
 
 class DistillerWrapperEnvironment(gym.Env):
-    metadata = {'render.modes': ['human']}
-
     def __init__(self, model, app_args, amc_cfg, services):
         self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
         self.tflogger = distiller.data_loggers.TensorBoardLogger(msglogger.logdir)
@@ -573,23 +586,26 @@ class DistillerWrapperEnvironment(gym.Env):
     def step(self, pruning_action):
         """Take a step, given an action.
 
-        The action represents the desired sparsity.
+        The action represents the desired sparsity for the "current" layer.
         This function is invoked by the Agent.
         """
         msglogger.info("env.step - current_layer_id={} episode={}".format(self.current_layer_id, self.episode))
+        pruning_action = pruning_action[0]
         msglogger.info("\tAgent pruning_action={}".format(pruning_action))
+        self.agent_action_history.append(pruning_action)
 
         if is_using_continuous_action_space(self.amc_cfg.agent_algo):
-            pruning_action = np.clip(pruning_action[0], self.action_low, self.action_high)
+            pruning_action = np.clip(pruning_action, self.action_low, self.action_high)
         else:
             # Divide the action space into 10 discrete levels (0%, 10%, 20%,....90% sparsity)
             pruning_action = pruning_action / 10
         msglogger.info("\tAgent clipped pruning_action={}".format(pruning_action))
-        self.agent_action_history.append(pruning_action)
+
         if self.amc_cfg.action_constrain_fn is not None:
             pruning_action = self.amc_cfg.action_constrain_fn(self, pruning_action=pruning_action)
             msglogger.info("Constrained pruning_action={}".format(pruning_action))
 
+        # Calculate the final compression rate
         total_macs_before, _ = self.net_wrapper.get_model_resources_requirements(self.model)
         layer_macs = self.net_wrapper.get_layer_macs(self.current_layer())
         msglogger.info("\tlayer_macs={:.2f}".format(layer_macs / self.dense_model_macs))
@@ -599,7 +615,7 @@ class DistillerWrapperEnvironment(gym.Env):
         if pruning_action > 0:
             pruning_action = self.net_wrapper.remove_structures(self.current_layer_id,
                                                                 fraction_to_prune=pruning_action,
-                                                                prune_what="filters")
+                                                                prune_what=self.amc_cfg.pruning_pattern)
         else:
             pruning_action = 0
 
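
For reference, not part of the patch: the rename to remove_structures_fn keeps the local callable from shadowing the enclosing NetworkWrapper.remove_structures method, and the new amc_cfg.pruning_pattern field lets the same code path prune either filters or channels. Below is a minimal sketch of that dispatch, assuming only the distiller helpers already used in the hunks above (sparsity_3D/sparsity_ch, remove_filters/remove_channels); the helper name select_pruning_fns is illustrative and does not exist in distiller.

    import distiller

    def select_pruning_fns(pruning_pattern):
        """Map a pruning_pattern string to (sparsity fn, removal fn, group type)."""
        if pruning_pattern == "filters":
            # 3D (filter-wise) sparsity measurement and physical filter removal
            return distiller.sparsity_3D, distiller.remove_filters, "Filters"
        if pruning_pattern == "channels":
            # channel-wise sparsity measurement and physical channel removal
            return distiller.sparsity_ch, distiller.remove_channels, "Channels"
        raise ValueError("unsupported structure {}".format(pruning_pattern))

    # Hypothetical usage inside NetworkWrapper.remove_structures:
    # calculate_sparsity, remove_structures_fn, group_type = select_pruning_fns(prune_what)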