diff --git a/distiller/pruning/baidu_rnn_pruner.py b/distiller/pruning/baidu_rnn_pruner.py
index 79ca8721dc0af895b7b152121c4ab8e4685118eb..f195b8598cbc2614b9e4d7522e6672dc13bffd6c 100755
--- a/distiller/pruning/baidu_rnn_pruner.py
+++ b/distiller/pruning/baidu_rnn_pruner.py
@@ -26,18 +26,45 @@ class BaiduRNNPruner(_ParameterPruner):
     Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017).
     Exploring Sparsity in Recurrent Neural Networks.
     (https://arxiv.org/abs/1704.05119)
+
+    This implementation slightly differs from the algorithm in the original
+    paper, in that the paper changes the pruning rate at training-step
+    granularity, while Distiller controls the pruning rate at epoch granularity.
+
+    Equation (1):
+
+                  2 * q * freq
+    start_slope = -------------------------------------------------------
+                  2 * (ramp_itr - start_itr) + 3 * (end_itr - ramp_itr)
+
+
+    Pruning algorithm (1):
+
+    if current_itr < ramp_itr then
+        threshold = start_slope * (current_itr - start_itr + 1) / freq
+    else
+        threshold = (start_slope * (ramp_itr - start_itr + 1) +
+                     ramp_slope * (current_itr - ramp_itr + 1)) / freq
+    end if
+
+    mask = abs(param) < threshold
     """
 
-    def __init__(self, name, initial_sparsity, final_sparsity, q, ramp_epoch, ramp_slope_mult, weights):
+    def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights):
+        # Initialize the pruner, using a configuration that originates from the
+        # schedule YAML file.
         super(BaiduRNNPruner, self).__init__(name)
-        self.initial_sparsity = initial_sparsity
-        self.final_sparsity = final_sparsity
-        assert final_sparsity > initial_sparsity
         self.params_names = weights
         assert self.params_names
 
+        # This is the 'q' value that appears in equation (1) of the paper.
         self.q = q
-        self.ramp_epoch = ramp_epoch
+        # This is the number of epochs to wait after starting_epoch, before we
+        # begin ramping up the pruning rate.
+        # In other words, between epochs 'starting_epoch' and 'starting_epoch' +
+        # self.ramp_epoch_offset, the pruning slope is 'self.start_slope'.  After
+        # that, the slope is 'self.ramp_slope'.
+        self.ramp_epoch_offset = ramp_epoch_offset
         self.ramp_slope_mult = ramp_slope_mult
         self.ramp_slope = None
         self.start_slope = None
@@ -51,16 +78,19 @@ class BaiduRNNPruner(_ParameterPruner):
         ending_epoch = meta['ending_epoch']
         freq = meta['frequency']
 
+        ramp_epoch = self.ramp_epoch_offset + starting_epoch
+
         # Calculate start slope
         if self.start_slope is None:
-            self.start_slope = (2 * self.q * freq) / (2*(self.ramp_epoch - starting_epoch) + 3*(ending_epoch - self.ramp_epoch))
+            # We want to calculate these values only once, and then cache them.
+            self.start_slope = (2 * self.q * freq) / (2*(ramp_epoch - starting_epoch) + 3*(ending_epoch - ramp_epoch))
             self.ramp_slope = self.start_slope * self.ramp_slope_mult
 
-        if current_epoch < self.ramp_epoch:
+        if current_epoch < ramp_epoch:
             eps = self.start_slope * (current_epoch - starting_epoch + 1) / freq
         else:
-            eps = (self.start_slope * (self.ramp_epoch - starting_epoch + 1) +
-                   self.ramp_slope * (current_epoch - self.ramp_epoch + 1)) / freq
+            eps = (self.start_slope * (ramp_epoch - starting_epoch + 1) +
+                   self.ramp_slope * (current_epoch - ramp_epoch + 1)) / freq
 
         # After computing the threshold, we can create the mask
         zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps)
diff --git a/examples/baidu-rnn-pruning/word_lang_model.schedule_baidu_rnn.yaml b/examples/baidu-rnn-pruning/word_lang_model.schedule_baidu_rnn.yaml
index b2c004e499cdc97033f7ec236274ea4f11c4c77e..f714a0a295e425fe323bb0691a4df499d2e9a9e8 100755
--- a/examples/baidu-rnn-pruning/word_lang_model.schedule_baidu_rnn.yaml
+++ b/examples/baidu-rnn-pruning/word_lang_model.schedule_baidu_rnn.yaml
@@ -37,46 +37,36 @@ version: 1
 pruners:
   ih_l0_rnn_pruner:
     class: BaiduRNNPruner
-    initial_sparsity : 0.05
-    final_sparsity: 0.50
     q: 0.17
-    ramp_epoch: 2
+    ramp_epoch_offset: 3
     ramp_slope_mult: 2
     weights: [rnn.weight_ih_l0]
 
   hh_l0_rnn_pruner:
     class: BaiduRNNPruner
-    initial_sparsity : 0.05
-    final_sparsity: 0.50
     q: 0.11
-    ramp_epoch: 2
+    ramp_epoch_offset: 3
     ramp_slope_mult: 2
     weights: [rnn.weight_hh_l0]
 
   ih_l1_rnn_pruner:
     class: BaiduRNNPruner
-    initial_sparsity : 0.05
-    final_sparsity: 0.60
     q: 0.18
-    ramp_epoch: 2
+    ramp_epoch_offset: 3
     ramp_slope_mult: 2
     weights: [rnn.weight_ih_l1]
 
   hh_l1_rnn_pruner:
     class: BaiduRNNPruner
-    initial_sparsity : 0.05
-    final_sparsity: 0.60
     q: 0.15
-    ramp_epoch: 2
+    ramp_epoch_offset: 3
     ramp_slope_mult: 2
     weights: [rnn.weight_hh_l1]
 
   embedding_pruner:
     class: BaiduRNNPruner
-    initial_sparsity : 0.05
-    final_sparsity: 0.80
     q: 0.16
-    ramp_epoch: 2
+    ramp_epoch_offset: 3
     ramp_slope_mult: 2
     weights: [encoder.weight]
 
@@ -84,29 +74,29 @@ policies:
   - pruner:
       instance_name : ih_l0_rnn_pruner
     starting_epoch: 4
-    ending_epoch: 20
-    frequency: 1
+    ending_epoch: 21
+    frequency: 3
 
   - pruner:
       instance_name : hh_l0_rnn_pruner
     starting_epoch: 4
-    ending_epoch: 20
-    frequency: 1
+    ending_epoch: 21
+    frequency: 3
 
   - pruner:
       instance_name : ih_l1_rnn_pruner
-    starting_epoch: 4
-    ending_epoch: 20
-    frequency: 1
+    starting_epoch: 5
+    ending_epoch: 22
+    frequency: 3
 
   - pruner:
       instance_name : hh_l1_rnn_pruner
-    starting_epoch: 4
-    ending_epoch: 20
-    frequency: 1
+    starting_epoch: 5
+    ending_epoch: 22
+    frequency: 3
 
   - pruner:
       instance_name : embedding_pruner
-    starting_epoch: 5
-    ending_epoch: 21
-    frequency: 1
+    starting_epoch: 6
+    ending_epoch: 23
+    frequency: 3
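A note for reviewers: the sketch below is not part of the patch. It reproduces equation (1) and the two-slope threshold schedule that the patched pruner computes per epoch, so the interaction between `q`, `frequency`, and the new `ramp_epoch_offset` knob can be checked in isolation. The function name `baidu_rnn_threshold` and the driver loop are illustrative only; the constants come from the `ih_l0_rnn_pruner` schedule above.

```python
# Standalone sketch (illustrative, not part of the patch) of the threshold
# schedule computed by the patched BaiduRNNPruner.
def baidu_rnn_threshold(current_epoch, starting_epoch, ending_epoch, freq,
                        q, ramp_epoch_offset, ramp_slope_mult):
    """Return the magnitude threshold ('eps') used at a given epoch."""
    ramp_epoch = starting_epoch + ramp_epoch_offset

    # Equation (1): the start slope depends only on the schedule's shape,
    # which is why the real pruner computes and caches it once.
    start_slope = (2 * q * freq) / (2 * (ramp_epoch - starting_epoch) +
                                    3 * (ending_epoch - ramp_epoch))
    ramp_slope = start_slope * ramp_slope_mult

    if current_epoch < ramp_epoch:
        # Before ramp_epoch, the threshold grows gently at start_slope.
        return start_slope * (current_epoch - starting_epoch + 1) / freq
    # From ramp_epoch onward, the steeper ramp_slope takes over.
    return (start_slope * (ramp_epoch - starting_epoch + 1) +
            ramp_slope * (current_epoch - ramp_epoch + 1)) / freq


if __name__ == '__main__':
    # The ih_l0_rnn_pruner schedule from the YAML above: q=0.17,
    # ramp_epoch_offset=3, ramp_slope_mult=2, starting_epoch=4,
    # ending_epoch=21, frequency=3.
    for epoch in range(4, 21):
        eps = baidu_rnn_threshold(epoch, 4, 21, 3, 0.17, 3, 2)
        print('epoch {}: threshold = {:.5f}'.format(epoch, eps))
```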
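One clarification on the final masking step: the docstring's `mask = abs(param) < threshold` names the weights that get *pruned*, while the mask handed to `zeros_mask_dict` keeps the complement. A minimal sketch of that step, assuming `distiller.threshold_mask` keeps weights whose magnitude exceeds the threshold:

```python
import torch

def threshold_mask_sketch(param: torch.Tensor, threshold: float) -> torch.Tensor:
    # Assumed behavior of distiller.threshold_mask: 1.0 where the weight
    # survives (|w| > threshold), 0.0 where it is pruned -- i.e. the
    # complement of the docstring's "mask = abs(param) < threshold".
    return torch.gt(torch.abs(param), threshold).type(param.type())

weight = torch.randn(4, 4)
mask = threshold_mask_sketch(weight, threshold=0.5)
pruned_weight = weight * mask  # small-magnitude weights are zeroed
```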