Skip to content
Snippets Groups Projects
Unverified Commit 9f0c0832 authored by Neta Zmora's avatar Neta Zmora Committed by GitHub
Browse files

Mag pruner doc (#33)

MagnitudeParameterPruner: document and test

This is in response to a question in issue #19 
parent d97786ee
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,7 @@
from .pruner import _ParameterPruner
import distiller
class MagnitudeParameterPruner(_ParameterPruner):
"""This is the most basic magnitude-based pruner.
......@@ -26,8 +27,23 @@ class MagnitudeParameterPruner(_ParameterPruner):
"""
def __init__(self, name, thresholds, **kwargs):
    """Construct a magnitude-based pruner.

    Usually, a Pruner is constructed by the compression schedule parser
    found in distiller/config.py.
    The constructor is passed a dictionary of thresholds, as explained below.

    Args:
        name (string): the name of the pruner (used only for debug)
        thresholds (dict): a dictionary of thresholds, with the key being the
            parameter name.
            A special key, '*', represents the default threshold value. If
            set_param_mask is invoked on a parameter tensor that does not have
            an explicit entry in the 'thresholds' dictionary, then this default
            value is used.
            Currently it is mandatory to include a '*' key in 'thresholds'.

    Raises:
        AssertionError: if `thresholds` is None, or has no '*' default entry.
    """
    super(MagnitudeParameterPruner, self).__init__(name)
    # NOTE: kept as asserts (not ValueError) because callers expect
    # AssertionError on bad input (see test_magnitude_pruning).
    # The original `assert 'thresholds' is not None` tested a string
    # literal and was always true; removed as a no-op.
    assert thresholds is not None, "a thresholds dictionary is required"
    # Make sure there is a default ('*') threshold to use
    assert '*' in thresholds, "thresholds must contain a '*' default entry"
    self.thresholds = thresholds
......
......@@ -49,3 +49,7 @@ def get_dummy_input(dataset):
elif dataset == "cifar10":
return torch.randn(1, 3, 32, 32).cuda()
raise ValueError("Trying to use an unknown dataset " + dataset)
def almost_equal(a, b, max_diff=0.000001):
    """Return True when a and b differ by at most max_diff (absolute tolerance)."""
    difference = a - b
    if difference < 0:
        difference = -difference
    return difference <= max_diff
......@@ -14,6 +14,7 @@
# limitations under the License.
#
from collections import namedtuple
import numpy as np
import logging
import torch
import os
......@@ -358,6 +359,51 @@ def test_conv_fc_interface(is_parallel=parallel, model=None, zeros_mask_dict=Non
run_forward_backward(model, optimizer, dummy_input)
def test_threshold_mask():
    # Build a 4-D all-ones tensor, then push exactly one element below
    # the threshold we are about to apply.
    tensor = torch.ones(3, 64, 32, 32)
    tensor[1, 4, 17, 31] = 0.2
    # Thresholding at 0.3 should zero out that single element and no other.
    mask = distiller.threshold_mask(tensor, threshold=0.3)
    assert np.sum(distiller.to_np(mask)) == (distiller.volume(tensor) - 1)
    assert mask[1, 4, 17, 31] == 0
    assert common.almost_equal(distiller.sparsity(mask), 1/distiller.volume(tensor))
def test_magnitude_pruning():
    # A 4-D all-ones tensor with a single sub-threshold element.
    param = torch.ones(3, 64, 32, 32)
    param[1, 4, 17, 31] = 0.2

    # Masks dictionary holding one ParameterMasker for parameter 'a'.
    zeros_mask_dict = {'a': distiller.ParameterMasker('a')}

    # Constructing a MagnitudeParameterPruner without defining a default
    # threshold must fail.
    with pytest.raises(AssertionError):
        distiller.pruning.MagnitudeParameterPruner("test", None)

    # Now provide the default ('*') threshold.
    pruner = distiller.pruning.MagnitudeParameterPruner("test", {"*": 0.4})
    assert distiller.sparsity(param) == 0

    # Create a mask for parameter 'a': exactly one element falls below 0.4.
    pruner.set_param_mask(param, 'a', zeros_mask_dict, None)
    assert common.almost_equal(distiller.sparsity(zeros_mask_dict['a'].mask), 1/distiller.volume(param))

    # Use the masker to actually prune the parameter.
    masker = zeros_mask_dict['a']
    masker.apply_mask(param)
    assert common.almost_equal(distiller.sparsity(param), 1/distiller.volume(param))

    # The same masker works on any tensor of the correct shape.
    # The mask already exists, so this is pruning - not thresholding.
    other = torch.ones(3, 64, 32, 32)
    other[:] = 0.3
    masker.apply_mask(other)
    assert common.almost_equal(distiller.sparsity(other), 1/distiller.volume(param))
if __name__ == '__main__':
for is_parallel in [True, False]:
test_ranked_filter_pruning(is_parallel)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment