diff --git a/distiller/policy.py b/distiller/policy.py index ca3f1c1a8de2b9fa6da9b84b75c786c85ff5938a..76d295f5f059c14867203479c883b1aa1e9d66d6 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -40,18 +40,18 @@ class ScheduledTrainingPolicy(object):
         """A new epcoh is about to begin"""
         pass

-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None):
         """The forward-pass of a new mini-batch is about to begin"""
         pass

     def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
-                             regularizer_loss, zeros_mask_dict):
+                             regularizer_loss, zeros_mask_dict, optimizer=None):
         """The mini-batch training pass has completed the forward-pass, and is about to begin the backward pass.
         """
         pass

-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         """The mini-batch training pass has ended"""
         pass

@@ -81,7 +81,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
         for param_name, param in model.named_parameters():
             self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)

-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         for param_name, param in model.named_parameters():
             zeros_mask_dict[param_name].apply_mask(param)

@@ -100,11 +100,11 @@ class RegularizationPolicy(ScheduledTrainingPolicy):
         self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)

     def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
-                             regularizer_loss, zeros_mask_dict):
+                             regularizer_loss, zeros_mask_dict, optimizer=None):
         for param_name, param in model.named_parameters():
             self.regularizer.loss(param, param_name, regularizer_loss, zeros_mask_dict)

-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         if self.regularizer.threshold_criteria is None:
             return

@@ -141,7 +141,7 @@ class QuantizationPolicy(ScheduledTrainingPolicy):
         self.quantizer.prepare_model()
         self.quantizer.quantize_params()

-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         # After parameters update, quantize the parameters again
         # (Doing this here ensures the model parameters are quantized at training completion (and at validation time)
         self.quantizer.quantize_params()
diff --git a/distiller/scheduler.py b/distiller/scheduler.py index e4c1d140c51a91a1de604936bbdf5dbe60d0fd92..ba7a56b95b14c5232eeb7640cc17e75f9b5a8932 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -86,24 +86,24 @@ class CompressionScheduler(object):
             self.policies[epoch].append(policy)
         assert len(self.policies[epoch]) > 0

-        self.sched_metadata[policy] = { 'starting_epoch' : starting_epoch,
-                                        'ending_epoch' : ending_epoch,
-                                        'frequency' : frequency }
+        self.sched_metadata[policy] = {'starting_epoch': starting_epoch,
+                                       'ending_epoch': ending_epoch,
+                                       'frequency': frequency}

-    def on_epoch_begin(self, epoch):
+    def on_epoch_begin(self, epoch, 
optimizer=None): if epoch in self.policies: for policy in self.policies[epoch]: meta = self.sched_metadata[policy] meta['current_epoch'] = epoch policy.on_epoch_begin(self.model, self.zeros_mask_dict, meta) - def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch): + def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None): if epoch in self.policies: for policy in self.policies[epoch]: policy.on_minibatch_begin(self.model, epoch, minibatch_id, minibatches_per_epoch, - self.zeros_mask_dict) + self.zeros_mask_dict, optimizer) - def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss): + def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss, optimizer=None): # Last chance to compute the regularization loss, and optionally add it to the data loss regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device) @@ -114,25 +114,26 @@ class CompressionScheduler(object): loss, regularizer_loss, self.zeros_mask_dict) return regularizer_loss - def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch): + def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None): # When we get to this point, the weights are no longer maksed. This is because during the backward # pass, the weights are updated. So we choose to lazily apply the pruning mask, only if some # component is being called-back. weights_are_masked = False - + if epoch in self.policies: for policy in self.policies[epoch]: if not weights_are_masked: self.apply_mask() weights_are_masked = True policy.on_minibatch_end(self.model, epoch, minibatch_id, minibatches_per_epoch, - self.zeros_mask_dict) + self.zeros_mask_dict, optimizer) - def on_epoch_end(self, epoch): + def on_epoch_end(self, epoch, optimizer=None): if epoch in self.policies: for policy in self.policies[epoch]: meta = self.sched_metadata[policy] meta['current_epoch'] = epoch + meta['optimizer'] = optimizer policy.on_epoch_end(self.model, self.zeros_mask_dict, meta) def apply_mask(self): diff --git a/distiller/thinning.py b/distiller/thinning.py index eb27e52624057c6457627b96664c0c0c006b2492..94346f34750c26763e564a4150ad816f3bfaeaee 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -149,10 +149,10 @@ def resnet_cifar_remove_layers(model): model.module.layer_gates[layer][block][conv] = False -def remove_channels(model, zeros_mask_dict, arch, dataset): +def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer): sgraph = create_graph(dataset, arch) thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict) - apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe) + apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer) return model @@ -188,10 +188,10 @@ def find_nonzero_channels_list(param, param_name): return nnz_channels.cpu().numpy().tolist() -def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe): +def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer): if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0: # Now actually remove the filters, chaneels and make the weight tensors smaller - execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe) + execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe, optimizer) # Stash the recipe, so that it will be serialized together with the model if hasattr(model, 'thinning_recipes'): @@ -204,10 +204,10 @@ def apply_and_save_recipe(model, zeros_mask_dict, 
thinning_recipe):
         msglogger.error("Failed to create a thinning recipe")


-def remove_filters(model, zeros_mask_dict, arch, dataset):
+def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
     sgraph = create_graph(dataset, arch)
     thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
-    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe)
+    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model

@@ -375,7 +375,7 @@ class ChannelRemover(ScheduledTrainingPolicy):
         self.dataset = dataset

     def on_epoch_end(self, model, zeros_mask_dict, meta):
-        self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset)
+        self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset, meta.get('optimizer', None))


 class FilterRemover(ScheduledTrainingPolicy):
@@ -387,24 +387,24 @@ class FilterRemover(ScheduledTrainingPolicy):
         self.done = False
         self.active_cb = "on_minibatch_begin"

-    def __apply(self, model, zeros_mask_dict):
+    def __apply(self, model, zeros_mask_dict, optimizer):
         if not self.done:
             # We want to execute the thinning function only once, not every invocation of on_minibatch_begin
-            self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset)
+            self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset, optimizer=optimizer)
             self.done = True

-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         # We hook onto the on_minibatch_begin because we want to run after the pruner which sparsified
         # the tensors.  Pruners configure their pruning mask in on_epoch_begin, but apply the mask
         # only in on_minibatch_begin
         if self.active_cb != "on_minibatch_begin":
             return
-        self.__apply(model, zeros_mask_dict)
+        self.__apply(model, zeros_mask_dict, optimizer)

-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         if self.active_cb != "on_minibatch_end":
             return
-        self.__apply(model, zeros_mask_dict)
+        self.__apply(model, zeros_mask_dict, optimizer)

     def on_epoch_end(self, model, zeros_mask_dict, meta):
         # The epoch has ended and we reset the 'done' flag, so that the FilterRemover instance can be reused
@@ -416,15 +416,45 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
     # to a thinned model.  For example, this is invoked when loading a model from a checkpoint.
     for i, recipe in enumerate(recipe_list):
         msglogger.info("recipe %d:" % i)
-        execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=True)
+        execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer=None, loaded_from_file=True)


-def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=False):
+def optimizer_thinning(optimizer, param, dim, indices, new_shape=None):
+    """Adjust the size of the SGD velocity-tracking tensors.
+
+    The SGD momentum update (velocity) is dependent on the weights, and because during thinning we
+    dynamically change the weight shapes, we need to make the appropriate changes in the Optimizer,
+    or disable the momentum.
+
+    This function is brittle as it is tested on SGD only and relies on the internal representation of
+    the SGD optimizer, which can change w/o notice. 
+ """ + if optimizer is None or not isinstance(optimizer, torch.optim.SGD): + return False + for group in optimizer.param_groups: + momentum = group.get('momentum', 0) + if momentum == 0: + continue + for p in group['params']: + if id(p) != id(param): + continue + param_state = optimizer.state[p] + if 'momentum_buffer' in param_state: + param_state['momentum_buffer'] = torch.index_select(param_state['momentum_buffer'], dim, indices) + if new_shape is not None: + msglogger.info("optimizer_thinning: new shape {}".format(*new_shape)) + param_state['momentum_buffer'] = param_state['momentum_buffer'].resize_(*new_shape) + return True + return False + + +def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_from_file=False): """Apply a thinning recipe to a model. This will remove filters and channels, as well as handle batch-normalization parameter adjustment, and thinning of weight tensors. """ + layers = {} for name, m in model.named_modules(): layers[name] = m @@ -458,12 +488,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal # Check if we're trying to trim a parameter that is already "thin" if param.data.size(dim) != len_indices: param.data = torch.index_select(selection_view, dim, indices) - - if param.grad is not None: - # We also need to change the dimensions of the gradient tensor. - grad_selection_view = param.grad.resize_(*directive[2]) - if grad_selection_view.size(dim) != len_indices: - param.grad = torch.index_select(grad_selection_view, dim, indices) + if param.grad is not None: + # We also need to change the dimensions of the gradient tensor. + grad_selection_view = param.grad.resize_(*directive[2]) + if grad_selection_view.size(dim) != len_indices: + param.grad = torch.index_select(grad_selection_view, dim, indices) + if optimizer_thinning(optimizer, param, dim, indices, directive[3]): + msglogger.info("Updated [4D] velocity buffer for {} (dim={},size={},shape={})". + format(param_name, dim, len_indices, directive[3])) param.data = param.view(*directive[3]) if param.grad is not None: @@ -471,12 +503,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal else: if param.data.size(dim) != len_indices: param.data = torch.index_select(param.data, dim, indices) + msglogger.info("[thinning] changed param {} shape: {}".format(param_name, len_indices)) # We also need to change the dimensions of the gradient tensor. # If have not done a backward-pass thus far, then the gradient will # not exist, and therefore won't need to be re-dimensioned. 
if param.grad is not None and param.grad.size(dim) != len_indices: - param.grad = torch.index_select(param.grad, dim, indices) - msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len_indices)) + param.grad = torch.index_select(param.grad, dim, indices) + if optimizer_thinning(optimizer, param, dim, indices): + msglogger.info("Updated velocity buffer %s" % param_name) if not loaded_from_file: # If the masks are loaded from a checkpoint file, then we don't need to change diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index f33e171980a20646649a5b0dc3dd82b9b957f99c..c6db0b880ba27097f38de87ce7aa5441f26190cb 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -223,10 +223,6 @@ def main(): model, compression_scheduler, start_epoch = apputils.load_checkpoint( model, chkpt_file=args.resume) - if 'resnet' in args.arch and 'preact' not in args.arch and 'cifar' in args.arch: - distiller.resnet_cifar_remove_layers(model) - #model = distiller.resnet_cifar_remove_channels(model, compression_scheduler.zeros_mask_dict) - # Define loss function (criterion) and optimizer criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, @@ -320,11 +316,11 @@ def main(): OrderedDict([('Loss', vloss), ('Top1', top1), ('Top5', top5)])) - distiller.log_training_progress(stats, None,epoch, steps_completed=0, total_steps=1, + distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1, loggers=[tflogger]) if compression_scheduler: - compression_scheduler.on_epoch_end(epoch) + compression_scheduler.on_epoch_end(epoch, optimizer) # remember best top1 and save checkpoint is_best = top1 > best_top1 @@ -339,8 +335,8 @@ def main(): def train(train_loader, model, criterion, optimizer, epoch, compression_scheduler, loggers, print_freq, log_params_hist): """Training loop for one epoch.""" - losses = {'objective_loss' : tnt.AverageValueMeter(), - 'regularizer_loss' : tnt.AverageValueMeter()} + losses = {'objective_loss': tnt.AverageValueMeter(), + 'regularizer_loss': tnt.AverageValueMeter()} if compression_scheduler is None: # Initialize the regularizer loss to zero losses['regularizer_loss'].add(0) @@ -368,7 +364,7 @@ def train(train_loader, model, criterion, optimizer, epoch, # Execute the forward phase, compute the output and measure loss if compression_scheduler: - compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch) + compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch, optimizer) output = model(input_var) loss = criterion(output, target_var) @@ -378,7 +374,7 @@ def train(train_loader, model, criterion, optimizer, epoch, if compression_scheduler: # Before running the backward phase, we add any regularization loss computed by the scheduler - regularizer_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss) + regularizer_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss, optimizer) loss += regularizer_loss losses['regularizer_loss'].add(regularizer_loss.item()) @@ -387,7 +383,7 @@ def train(train_loader, model, criterion, optimizer, epoch, loss.backward() optimizer.step() if compression_scheduler: - compression_scheduler.on_minibatch_end(epoch, train_step, steps_per_epoch) + compression_scheduler.on_minibatch_end(epoch, train_step, 
steps_per_epoch, optimizer) # measure elapsed time batch_time.add(time.time() - end) diff --git a/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml b/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml index 9efc6dc69c1362538721c6b723324d0750f2b1e3..cb20dd1ea5aea4e8d58803ce14b35cffaf10990b 100755 --- a/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml +++ b/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml @@ -2,7 +2,7 @@ # This schedule performs 3D (filter-wise) regularization of some of the convolution layers, together with # element-wise pruning using sensitivity-pruning. # -# time python3 compress_classifier.py -a=vgg19 -p=50 ../../../data.imagenet --epochs=10 --lr=0.00001 --compress=../pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml --pretrained --momentum=0 +# time python3 compress_classifier.py -a=vgg19 -p=50 ../../../data.imagenet --epochs=10 --lr=0.00001 --compress=../pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml --pretrained # diff --git a/examples/ssl/ssl_4D-removal_training.yaml b/examples/ssl/ssl_4D-removal_training.yaml index 0fcbe0f8dd95e569d9dab2dcf2440f528c587981..ef1b120452e8b07a8e2e595c4d76589bf2bec93a 100755 --- a/examples/ssl/ssl_4D-removal_training.yaml +++ b/examples/ssl/ssl_4D-removal_training.yaml @@ -1,3 +1,13 @@ +# +# THIS IS CURRENTLY BROKEN! +# There are two challenges to getting this to work: +# 1. If we want to remove layers, we should use a thinning recipe and leave instructions. Otherwise, it's very hard to +# decide if we should remove layers or not. +# 2. We need a generic solution for removing layers. Perhaps a new module type which wraps a layer and has a bypass-gate. +# +# If you still want to use this, until we implement a proper solution, invoke thinning.resnet_cifar_remove_layers from +# compress_classifier.py. +# # We used this schedule to train CIFAR10-ResNet20 from scratch with SSL. # After running this schedule, use ssl_4D-removal_finetuning.yaml to "physically" remove the layers and fine-tune the smaller model. 
# diff --git a/tests/test_pruning.py b/tests/test_pruning.py index b6c6a06e1a91c50b6ac2509f1ab28041651bfc19..b16e9012af1941d7cbbfbe86b106c8927c9cc494 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -35,7 +35,7 @@ fh = logging.FileHandler('test.log') logger = logging.getLogger() logger.addHandler(fh) -NetConfig = namedtuple("test_config", "arch dataset conv1_name conv2_name bn_name") +NetConfig = namedtuple("test_config", "arch dataset bn_name module_pairs") # @@ -43,20 +43,34 @@ NetConfig = namedtuple("test_config", "arch dataset conv1_name conv2_name bn_nam # def simplenet(): return NetConfig(arch="simplenet_cifar", dataset="cifar10", - conv1_name="conv1", conv2_name="conv2", + module_pairs=[("conv1", "conv2")], bn_name=None) def resnet20_cifar(): return NetConfig(arch="resnet20_cifar", dataset="cifar10", - conv1_name="layer1.0.conv1", conv2_name="layer1.0.conv2", + module_pairs=[("layer1.0.conv1", "layer1.0.conv2")], bn_name="layer1.0.bn1") +def vgg19_imagenet(): + return NetConfig(arch="vgg19", dataset="imagenet", + module_pairs=[("features.21", "features.23"), + ("features.23", "features.25"), + ("features.25", "features.28"), + ("features.28", "features.30"), + ("features.30", "features.32"), + ("features.32", "features.34")], + bn_name=None) + + def test_ranked_filter_pruning(): ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.1) ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.5) ranked_filter_pruning(simplenet(), ratio_to_prune=0.5) + ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1) + model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1) + test_conv_fc_interface(model, zeros_mask_dict) def test_prune_all_filters(): @@ -76,44 +90,46 @@ def ranked_filter_pruning(config, ratio_to_prune): """ model, zeros_mask_dict = common.setup_test(config.arch, config.dataset) - # Test that we can access the weights tensor of the first convolution in layer 1 - conv1_p = distiller.model_find_param(model, config.conv1_name + ".weight") - assert conv1_p is not None - num_filters = conv1_p.size(0) - - # Test that there are no zero-filters - assert distiller.sparsity_3D(conv1_p) == 0.0 - - # Create a filter-ranking pruner - reg_regims = {config.conv1_name + ".weight": [ratio_to_prune, "3D"]} - pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims) - pruner.set_param_mask(conv1_p, config.conv1_name + ".weight", zeros_mask_dict, meta=None) - - conv1 = common.find_module_by_name(model, config.conv1_name) - assert conv1 is not None - # Test that the mask has the correct fraction of filters pruned. 
- # We asked for 10%, but there are only 16 filters, so we have to settle for 1/16 filters - expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels) - expected_pruning = expected_cnt_removed_filters / conv1.out_channels - masker = zeros_mask_dict[config.conv1_name + ".weight"] - assert masker is not None - assert distiller.sparsity_3D(masker.mask) == expected_pruning - - # Use the mask to prune - assert distiller.sparsity_3D(conv1_p) == 0 - masker.apply_mask(conv1_p) - assert distiller.sparsity_3D(conv1_p) == expected_pruning - - # Remove filters - conv2 = common.find_module_by_name(model, config.conv2_name) - assert conv2 is not None - assert conv1.out_channels == num_filters - assert conv2.in_channels == num_filters + for pair in config.module_pairs: + # Test that we can access the weights tensor of the first convolution in layer 1 + conv1_p = distiller.model_find_param(model, pair[0] + ".weight") + assert conv1_p is not None + num_filters = conv1_p.size(0) + + # Test that there are no zero-filters + assert distiller.sparsity_3D(conv1_p) == 0.0 + + # Create a filter-ranking pruner + reg_regims = {pair[0] + ".weight": [ratio_to_prune, "3D"]} + pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims) + pruner.set_param_mask(conv1_p, pair[0] + ".weight", zeros_mask_dict, meta=None) + + conv1 = common.find_module_by_name(model, pair[0]) + assert conv1 is not None + # Test that the mask has the correct fraction of filters pruned. + # We asked for 10%, but there are only 16 filters, so we have to settle for 1/16 filters + expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels) + expected_pruning = expected_cnt_removed_filters / conv1.out_channels + masker = zeros_mask_dict[pair[0] + ".weight"] + assert masker is not None + assert distiller.sparsity_3D(masker.mask) == expected_pruning + + # Use the mask to prune + assert distiller.sparsity_3D(conv1_p) == 0 + masker.apply_mask(conv1_p) + assert distiller.sparsity_3D(conv1_p) == expected_pruning + + # Remove filters + conv2 = common.find_module_by_name(model, pair[1]) + assert conv2 is not None + assert conv1.out_channels == num_filters + assert conv2.in_channels == num_filters # Test thinning - distiller.remove_filters(model, zeros_mask_dict, config.arch, config.dataset) + distiller.remove_filters(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None) assert conv1.out_channels == num_filters - expected_cnt_removed_filters assert conv2.in_channels == num_filters - expected_cnt_removed_filters + return model, zeros_mask_dict def test_arbitrary_channel_pruning(): @@ -134,6 +150,40 @@ def test_channel_pruning_conv_bias(): arbitrary_channel_pruning(simplenet(), channels_to_remove=[0, 1]) +def create_channels_mask(conv_p, channels_to_remove): + assert conv_p.dim() == 4 + num_filters = conv_p.size(0) + num_channels = conv_p.size(1) + kernel_height = conv_p.size(2) + kernel_width = conv_p.size(3) + + # Let's build our 4D mask. 
+ # We start with a 1D mask of channels, with all but our specified channels set to one + channels = torch.ones(num_channels) + for ch in channels_to_remove: + channels[ch] = 0 + + # Now let's expand back up to a 4D mask + mask = channels.expand(num_filters, num_channels) + mask.unsqueeze_(-1) + mask.unsqueeze_(-1) + mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous() + + assert mask.shape == conv_p.shape + return mask + + +def run_forward_backward(model, optimizer, dummy_input): + criterion = torch.nn.CrossEntropyLoss().cuda() + model.train() + output = model(dummy_input) + target = torch.LongTensor(1).random_(2) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def arbitrary_channel_pruning(config, channels_to_remove): """Test removal of arbitrary channels. @@ -143,49 +193,34 @@ def arbitrary_channel_pruning(config, channels_to_remove): """ model, zeros_mask_dict = common.setup_test(config.arch, config.dataset) - conv2 = common.find_module_by_name(model, config.conv2_name) + assert len(config.module_pairs) == 1 # This is a temporary restriction on the test + pair = config.module_pairs[0] + conv2 = common.find_module_by_name(model, pair[1]) assert conv2 is not None # Test that we can access the weights tensor of the first convolution in layer 1 - conv2_p = distiller.model_find_param(model, config.conv2_name + ".weight") + conv2_p = distiller.model_find_param(model, pair[1] + ".weight") assert conv2_p is not None assert conv2_p.dim() == 4 - num_filters = conv2_p.size(0) num_channels = conv2_p.size(1) - kernel_height = conv2_p.size(2) - kernel_width = conv2_p.size(3) cnt_nnz_channels = num_channels - len(channels_to_remove) - - # Let's build our 4D mask. - # We start with a 1D mask of channels, with all but our specified channels set to one - channels = torch.ones(num_channels) - for ch in channels_to_remove: - channels[ch] = 0 - - # Now let's expand back up to a 4D mask - mask = channels.expand(num_filters, num_channels) - mask.unsqueeze_(-1) - mask.unsqueeze_(-1) - mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous() - - assert mask.shape == conv2_p.shape + mask = create_channels_mask(conv2_p, channels_to_remove) assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels - # Cool, so now we have a mask for pruning our channels. 
+ # Use the mask to prune - zeros_mask_dict[config.conv2_name + ".weight"].mask = mask - zeros_mask_dict[config.conv2_name + ".weight"].apply_mask(conv2_p) + zeros_mask_dict[pair[1] + ".weight"].mask = mask + zeros_mask_dict[pair[1] + ".weight"].apply_mask(conv2_p) all_channels = set([ch for ch in range(num_channels)]) - nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, config.conv2_name + ".weight")) + nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, pair[1] + ".weight")) channels_removed = all_channels - nnz_channels logger.info("Channels removed {}".format(channels_removed)) # Now, let's do the actual network thinning - distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset) - conv1 = common.find_module_by_name(model, config.conv1_name) - logger.info(conv1) - logger.info(conv2) + distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None) + conv1 = common.find_module_by_name(model, pair[0]) + assert conv1.out_channels == cnt_nnz_channels assert conv2.in_channels == cnt_nnz_channels assert conv1.weight.size(0) == cnt_nnz_channels @@ -198,6 +233,10 @@ def arbitrary_channel_pruning(config, channels_to_remove): assert bn1.bias.size(0) == cnt_nnz_channels assert bn1.weight.size(0) == cnt_nnz_channels + dummy_input = torch.randn(1, 3, 32, 32) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1) + run_forward_backward(model, optimizer, dummy_input) + # Let's test saving and loading a thinned model. # We save 3 times, and load twice, to make sure to cover some corner cases: # - Make sure that after loading, the model still has hold of the thinning recipes @@ -206,16 +245,17 @@ def arbitrary_channel_pruning(config, channels_to_remove): # (1) save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None) model_2 = create_model(False, config.dataset, config.arch, parallel=False) - dummy_input = torch.randn(1, 3, 32, 32) model(dummy_input) model_2(dummy_input) - conv2 = common.find_module_by_name(model_2, config.conv2_name) + conv2 = common.find_module_by_name(model_2, pair[1]) assert conv2 is not None with pytest.raises(KeyError): model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') compression_scheduler = distiller.CompressionScheduler(model) hasattr(model, 'thinning_recipes') + run_forward_backward(model, optimizer, dummy_input) + # (2) save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler) model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') @@ -229,7 +269,62 @@ def arbitrary_channel_pruning(config, channels_to_remove): logger.info("test_arbitrary_channel_pruning - Done 2") +def test_conv_fc_interface(model=None, zeros_mask_dict=None): + """A special case of convolution filter-pruning occurs when the next layer is + fully-connected (linear). This test is for this case and uses VGG16. 
+ """ + arch = "vgg19" + dataset = "imagenet" + ratio_to_prune = 0.1 + conv_name = "features.34" + fc_name = "classifier.0" + dummy_input = torch.randn(1, 3, 224, 224) + + if model is None or zeros_mask_dict is None: + model, zeros_mask_dict = common.setup_test(arch, dataset) + + # Run forward and backward passes, in order to create the gradients and optimizer params + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1) + run_forward_backward(model, optimizer, dummy_input) + + conv = common.find_module_by_name(model, conv_name) + assert conv is not None + + conv_p = distiller.model_find_param(model, conv_name + ".weight") + assert conv_p is not None + assert conv_p.dim() == 4 + + # Create a filter-ranking pruner + reg_regims = {conv_name + ".weight": [ratio_to_prune, "3D"]} + pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims) + pruner.set_param_mask(conv_p, conv_name + ".weight", zeros_mask_dict, meta=None) + + # Use the mask to prune + masker = zeros_mask_dict[conv_name + ".weight"] + assert masker is not None + masker.apply_mask(conv_p) + num_filters = conv_p.size(0) + expected_cnt_removed_filters = int(ratio_to_prune * conv.out_channels) + + # Remove filters + fc = common.find_module_by_name(model, fc_name) + assert fc is not None + + # Test thinning + fm_size = fc.in_features // conv.out_channels + num_nnz_filters = num_filters - expected_cnt_removed_filters + distiller.remove_filters(model, zeros_mask_dict, arch, dataset, optimizer) + assert conv.out_channels == num_nnz_filters + assert fc.in_features == fm_size * num_nnz_filters + + # Run again, to make sure the optimizer and gradients shapes were updated correctly + run_forward_backward(model, optimizer, dummy_input) + run_forward_backward(model, optimizer, dummy_input) + + if __name__ == '__main__': test_ranked_filter_pruning() test_arbitrary_channel_pruning() test_prune_all_channels() + model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1) + test_conv_fc_interface(model, zeros_mask_dict)
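Reviewer note: the sketch below is not part of the patch. It is a minimal illustration, assuming a generic PyTorch training loop (the model, train_loader and scheduler names are placeholders), of how the optimizer is now threaded through every CompressionScheduler callback whose signature changes above.

# Minimal sketch (hypothetical model/train_loader/scheduler objects) of a training loop that
# passes the optimizer to the new CompressionScheduler callback signatures.
import torch.nn as nn

def train_one_epoch(model, train_loader, scheduler, optimizer, epoch):
    criterion = nn.CrossEntropyLoss()
    steps_per_epoch = len(train_loader)
    scheduler.on_epoch_begin(epoch, optimizer)
    for step, (inputs, targets) in enumerate(train_loader):
        scheduler.on_minibatch_begin(epoch, step, steps_per_epoch, optimizer)
        loss = criterion(model(inputs), targets)
        # The scheduler gets a last chance to add a regularization term to the loss
        loss += scheduler.before_backward_pass(epoch, step, steps_per_epoch, loss, optimizer)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Policies such as FilterRemover now receive the optimizer so thinning can fix its state
        scheduler.on_minibatch_end(epoch, step, steps_per_epoch, optimizer)
    scheduler.on_epoch_end(epoch, optimizer)

This mirrors the call sites updated in examples/classifier_compression/compress_classifier.py.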
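A second, self-contained sketch (illustrative only, not code from this patch) of the problem optimizer_thinning() addresses: once filters are removed from a weight with torch.index_select, the SGD 'momentum_buffer' kept in optimizer.state must be shrunk the same way, or its shape no longer matches the parameter.

# Standalone demonstration of shrinking an SGD momentum buffer to match a thinned
# convolution weight, mirroring what optimizer_thinning() does for the real model.
import torch

conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
opt = torch.optim.SGD(conv.parameters(), lr=0.1, momentum=0.9)

# One forward/backward/step so that SGD allocates a momentum buffer for the weight
conv(torch.randn(1, 3, 8, 8)).sum().backward()
opt.step()

# "Thin" the weight: keep only filters 0, 2 and 4 (dim 0 of the 4D weight tensor)
keep = torch.tensor([0, 2, 4])
conv.weight.data = torch.index_select(conv.weight.data, 0, keep)

# Apply the same selection to the velocity tensor tracked by the optimizer
state = opt.state[conv.weight]
if 'momentum_buffer' in state:
    state['momentum_buffer'] = torch.index_select(state['momentum_buffer'], 0, keep)

assert state['momentum_buffer'].shape == conv.weight.shape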