Skip to content
Snippets Groups Projects
Commit d6ffeaf7 authored by Neta Zmora's avatar Neta Zmora
Browse files

Baidu RNN pruner: add documentation and fix schedule

parent ecade1b2
No related branches found
No related tags found
No related merge requests found
......@@ -26,18 +26,45 @@ class BaiduRNNPruner(_ParameterPruner):
Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017).
Exploring Sparsity in Recurrent Neural Networks.
(https://arxiv.org/abs/1704.05119)
This implementation slightly differs from the algorithm original paper in that
the algorithm changes the pruning rate at the training-step granularity, while
Distiller controls the pruning rate at epoch granularity.
Equation (1):
2 * q * freq
start_slope = -------------------------------------------------------
2 * (ramp_itr - start_itr ) + 3 * (end_itr - ramp_itr )
Pruning algorithm (1):
if current itr < ramp itr then
threshold = start_slope * (current_itr - start_itr + 1) / freq
else
threshold = (start_slope * (ramp_itr - start_itr + 1) +
ramp_slope * (current_itr - ramp_itr + 1)) / freq
end if
mask = abs(param) < threshold
"""
def __init__(self, name, initial_sparsity, final_sparsity, q, ramp_epoch, ramp_slope_mult, weights):
def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights):
# Initialize the pruner, using a configuration that originates from the
# schedule YAML file.
super(BaiduRNNPruner, self).__init__(name)
self.initial_sparsity = initial_sparsity
self.final_sparsity = final_sparsity
assert final_sparsity > initial_sparsity
self.params_names = weights
assert self.params_names
# This is the 'q' value that appears in equation (1) of the paper
self.q = q
self.ramp_epoch = ramp_epoch
# This is the number of epochs to wait after starting_epoch, before we
# begin ramping up the pruning rate.
# In other words, between epochs 'starting_epoch' and 'starting_epoch'+
# self.ramp_epoch_offset the pruning slope is 'self.start_slope'. After
# that, the slope is 'self.ramp_slope'
self.ramp_epoch_offset = ramp_epoch_offset
self.ramp_slope_mult = ramp_slope_mult
self.ramp_slope = None
self.start_slope = None
......@@ -51,16 +78,19 @@ class BaiduRNNPruner(_ParameterPruner):
ending_epoch = meta['ending_epoch']
freq = meta['frequency']
ramp_epoch = self.ramp_epoch_offset + starting_epoch
# Calculate start slope
if self.start_slope is None:
self.start_slope = (2 * self.q * freq) / (2*(self.ramp_epoch - starting_epoch) + 3*(ending_epoch - self.ramp_epoch))
# We want to calculate these values only once, and then cache them.
self.start_slope = (2 * self.q * freq) / (2*(ramp_epoch - starting_epoch) + 3*(ending_epoch - ramp_epoch))
self.ramp_slope = self.start_slope * self.ramp_slope_mult
if current_epoch < self.ramp_epoch:
if current_epoch < ramp_epoch:
eps = self.start_slope * (current_epoch - starting_epoch + 1) / freq
else:
eps = (self.start_slope * (self.ramp_epoch - starting_epoch + 1) +
self.ramp_slope * (current_epoch - self.ramp_epoch + 1)) / freq
eps = (self.start_slope * (ramp_epoch - starting_epoch + 1) +
self.ramp_slope * (current_epoch - ramp_epoch + 1)) / freq
# After computing the threshold, we can create the mask
zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps)
......@@ -37,46 +37,36 @@ version: 1
pruners:
ih_l0_rnn_pruner:
class: BaiduRNNPruner
initial_sparsity : 0.05
final_sparsity: 0.50
q: 0.17
ramp_epoch: 2
ramp_epoch_offset: 3
ramp_slope_mult: 2
weights: [rnn.weight_ih_l0]
hh_l0_rnn_pruner:
class: BaiduRNNPruner
initial_sparsity : 0.05
final_sparsity: 0.50
q: 0.11
ramp_epoch: 2
ramp_epoch_offset: 3
ramp_slope_mult: 2
weights: [rnn.weight_hh_l0]
ih_l1_rnn_pruner:
class: BaiduRNNPruner
initial_sparsity : 0.05
final_sparsity: 0.60
q: 0.18
ramp_epoch: 2
ramp_epoch_offset: 3
ramp_slope_mult: 2
weights: [rnn.weight_ih_l1]
hh_l1_rnn_pruner:
class: BaiduRNNPruner
initial_sparsity : 0.05
final_sparsity: 0.60
q: 0.15
ramp_epoch: 2
ramp_epoch_offset: 3
ramp_slope_mult: 2
weights: [rnn.weight_hh_l1]
embedding_pruner:
class: BaiduRNNPruner
initial_sparsity : 0.05
final_sparsity: 0.80
q: 0.16
ramp_epoch: 2
ramp_epoch_offset: 3
ramp_slope_mult: 2
weights: [encoder.weight]
......@@ -84,29 +74,29 @@ policies:
- pruner:
instance_name : ih_l0_rnn_pruner
starting_epoch: 4
ending_epoch: 20
frequency: 1
ending_epoch: 21
frequency: 3
- pruner:
instance_name : hh_l0_rnn_pruner
starting_epoch: 4
ending_epoch: 20
frequency: 1
ending_epoch: 21
frequency: 3
- pruner:
instance_name : ih_l1_rnn_pruner
starting_epoch: 4
ending_epoch: 20
frequency: 1
starting_epoch: 5
ending_epoch: 22
frequency: 3
- pruner:
instance_name : hh_l1_rnn_pruner
starting_epoch: 4
ending_epoch: 20
frequency: 1
starting_epoch: 5
ending_epoch: 22
frequency: 3
- pruner:
instance_name : embedding_pruner
starting_epoch: 5
ending_epoch: 21
frequency: 1
starting_epoch: 6
ending_epoch: 23
frequency: 3
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment