Skip to content
Snippets Groups Projects
Commit d8c97cdd authored by Neta Zmora
Browse files

AMC: Revive support for weights-channels removal

This is in contrast to weights-filters removal
parent 09d2eea3
No related branches found
No related tags found
No related merge requests found
......@@ -226,14 +226,14 @@ def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn
# Create a dictionary of parameters that Coach will handover to DistillerWrapperEnvironment
# Once it creates it.
services = distiller.utils.MutableNamedTuple({
'validate_fn': validate_fn,
'save_checkpoint_fn': save_checkpoint_fn,
'train_fn': train_fn})
'validate_fn': validate_fn,
'save_checkpoint_fn': save_checkpoint_fn,
'train_fn': train_fn})
app_args = distiller.utils.MutableNamedTuple({
'dataset': dataset,
'arch': arch,
'optimizer_data': optimizer_data})
'dataset': dataset,
'arch': arch,
'optimizer_data': optimizer_data})
amc_cfg = distiller.utils.MutableNamedTuple({
'protocol': args.amc_protocol,
......@@ -243,7 +243,8 @@ def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn
'action_range': action_range,
'conv_cnt': conv_cnt,
'reward_frequency': args.amc_reward_frequency,
'ft_frequency': args.amc_ft_frequency})
'ft_frequency': args.amc_ft_frequency,
'pruning_pattern': "filters"}) # "channels"}) #
#net_wrapper = NetworkWrapper(model, app_args, services)
#return sample_networks(net_wrapper, services)
......@@ -344,9 +345,21 @@ resnet56_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight
"module.layer3.3.conv1.weight", "module.layer3.4.conv1.weight", "module.layer3.5.conv1.weight",
"module.layer3.6.conv1.weight", "module.layer3.7.conv1.weight", "module.layer3.8.conv1.weight"]
# Weight-parameter names of the convolutions listed for the "plain20_cifar"
# architecture: layers 1-3, blocks 0-2 in each layer, conv1 and conv2 in each
# block.  The nesting order below reproduces the explicit listing order
# (layer-major, then block, then conv).
plain20_params = [
    "module.layer{}.{}.conv{}.weight".format(layer_idx, block_idx, conv_idx)
    for layer_idx in (1, 2, 3)
    for block_idx in (0, 1, 2)
    for conv_idx in (1, 2)
]
# Layer names are the parameter names with the trailing ".weight" characters
# sliced off.  Each architecture's list is built by the same expression; the
# generator yields them in the same order as the names on the left-hand side.
resnet50_layers, resnet20_layers, resnet56_layers, plain20_layers = (
    [pname[:-len(".weight")] for pname in arch_params]
    for arch_params in (resnet50_params, resnet20_params, resnet56_params, plain20_params)
)
class NetworkWrapper(object):
......@@ -375,6 +388,8 @@ class NetworkWrapper(object):
resnet_layers = resnet56_layers
elif self.app_args.arch == "resnet50":
resnet_layers = resnet50_layers
elif self.app_args.arch == "plain20_cifar":
resnet_layers = plain20_layers
return collect_conv_details(model, self.app_args.dataset, True, resnet_layers)
def num_layers(self):
......@@ -426,16 +441,16 @@ class NetworkWrapper(object):
conv_pname = layer.name + ".weight"
conv_p = distiller.model_find_param(self.model, conv_pname)
msglogger.info("ADC: removing %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname))
msglogger.info("ADC: trying to remove %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname))
if prune_what == "channels":
calculate_sparsity = distiller.sparsity_ch
remove_structures = distiller.remove_channels
remove_structures_fn = distiller.remove_channels
group_type = "Channels"
elif prune_what == "filters":
calculate_sparsity = distiller.sparsity_3D
group_type = "Filters"
remove_structures = distiller.remove_filters
remove_structures_fn = distiller.remove_filters
else:
raise ValueError("unsupported structure {}".format(prune_what))
# Create a channel-ranking pruner
......@@ -446,12 +461,12 @@ class NetworkWrapper(object):
if (self.zeros_mask_dict[conv_pname].mask is None or
calculate_sparsity(self.zeros_mask_dict[conv_pname].mask) == 0):
msglogger.info("remove_structures: aborting because there are no channels to prune")
msglogger.info("remove_structures: aborting because there are no structures to prune")
return 0
# Use the mask to prune
self.zeros_mask_dict[conv_pname].apply_mask(conv_p)
remove_structures(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None)
remove_structures_fn(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None)
conv_p = distiller.model_find_param(self.model, conv_pname)
return 1 - (self.get_layer_macs(layer) / macs_before)
......@@ -476,8 +491,6 @@ class NetworkWrapper(object):
class DistillerWrapperEnvironment(gym.Env):
metadata = {'render.modes': ['human']}
def __init__(self, model, app_args, amc_cfg, services):
self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
self.tflogger = distiller.data_loggers.TensorBoardLogger(msglogger.logdir)
......@@ -573,23 +586,26 @@ class DistillerWrapperEnvironment(gym.Env):
def step(self, pruning_action):
"""Take a step, given an action.
The action represents the desired sparsity.
The action represents the desired sparsity for the "current" layer.
This function is invoked by the Agent.
"""
msglogger.info("env.step - current_layer_id={} episode={}".format(self.current_layer_id, self.episode))
pruning_action = pruning_action[0]
msglogger.info("\tAgent pruning_action={}".format(pruning_action))
self.agent_action_history.append(pruning_action)
if is_using_continuous_action_space(self.amc_cfg.agent_algo):
pruning_action = np.clip(pruning_action[0], self.action_low, self.action_high)
pruning_action = np.clip(pruning_action, self.action_low, self.action_high)
else:
# Divide the action space into 10 discrete levels (0%, 10%, 20%,....90% sparsity)
pruning_action = pruning_action / 10
msglogger.info("\tAgent clipped pruning_action={}".format(pruning_action))
self.agent_action_history.append(pruning_action)
if self.amc_cfg.action_constrain_fn is not None:
pruning_action = self.amc_cfg.action_constrain_fn(self, pruning_action=pruning_action)
msglogger.info("Constrained pruning_action={}".format(pruning_action))
# Calculate the final compression rate
total_macs_before, _ = self.net_wrapper.get_model_resources_requirements(self.model)
layer_macs = self.net_wrapper.get_layer_macs(self.current_layer())
msglogger.info("\tlayer_macs={:.2f}".format(layer_macs / self.dense_model_macs))
......@@ -599,7 +615,7 @@ class DistillerWrapperEnvironment(gym.Env):
if pruning_action > 0:
pruning_action = self.net_wrapper.remove_structures(self.current_layer_id,
fraction_to_prune=pruning_action,
prune_what="filters")
prune_what=self.amc_cfg.pruning_pattern)
else:
pruning_action = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment