Bug fix: add support for thinning the optimizer
You no longer need to use `--momentum=0` when removing structures dynamically.

The SGD momentum update (velocity) depends on the weights, and PyTorch optimizers cache these velocity tensors internally. This caching was not a problem for filter/channel removal (thinning) because, although we dynamically change the shapes of the weight tensors, we don't replace the weight tensors themselves. However, PyTorch's SGD creates separate tensors to store the momentum updates, and these tensors have the same shape as the weight tensors. So when we change the shape of a weight tensor, we must make the corresponding change in the Optimizer's state, or disable momentum.

We added a new function, thinning.optimizer_thinning(), to do this. This function is brittle: it is tested only on optim.SGD, and it relies on the internal representation of the SGD optimizer, which can change without notice. For example, optim.Adam uses state['exp_avg'] and state['exp_avg_sq'], which also depend on the shape of the weight tensors.

We needed to pass the Optimizer instance to the thinning policies (ChannelRemover, FilterRemover) via the callbacks, which required us to change the callback interface. In the future we plan a bigger change to the callback API, to allow passing arbitrary context from the training environment to Distiller.

Also in this commit:
* compress_classifier.py had special handling for ResNet layer-removal, which is used in examples/ssl/ssl_4D-removal_training.yaml. This was a brittle and ugly hack, so until we have a more elegant solution, I'm removing support for layer-removal.
* Added to the tests invocations of forward and backward passes over a model. This exercises more of the real flows, which use the optimizer and construct gradient tensors.
* Added a test for a special case of convolution filter-pruning which occurs when the next layer is fully-connected (linear).
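To illustrate why the optimizer state must be thinned in lock-step with the weights, here is a minimal, dependency-free sketch. It models a weight tensor and its cached momentum buffer as nested lists; the function name `thin_filters` and the data are hypothetical, and this is not the actual thinning.optimizer_thinning() implementation — only the underlying idea: whatever indices survive in the weight tensor must also be the indices kept in the momentum buffer, otherwise the shapes diverge and the next momentum update fails or corrupts surviving filters.

```python
# Hypothetical sketch: pruning output filters (rows) from a weight
# "tensor" must prune the same rows from the SGD momentum buffer,
# because the buffer has the same shape as the weight tensor.

def thin_filters(weight, momentum, keep_indices):
    """Keep only the rows (output filters) listed in keep_indices,
    in both the weight tensor and its momentum buffer."""
    new_weight = [weight[i] for i in keep_indices]
    new_momentum = [momentum[i] for i in keep_indices]
    return new_weight, new_momentum

# 4 filters of 3 weights each, plus a matching momentum buffer.
weight = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
momentum = [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2],
            [0.3, 0.3, 0.3], [0.4, 0.4, 0.4]]

# Remove filters 1 and 3; both structures shrink together.
weight, momentum = thin_filters(weight, momentum, keep_indices=[0, 2])
assert len(weight) == len(momentum) == 2
print(weight, momentum)
```

In real PyTorch the momentum buffer lives in the optimizer's per-parameter state dict, which is why the commit passes the Optimizer instance to the thinning policies.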
Files changed:
- distiller/policy.py (+7, −7)
- distiller/scheduler.py (+12, −11)
- distiller/thinning.py (+57, −23)
- examples/classifier_compression/compress_classifier.py (+7, −11)
- examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml (+1, −1)
- examples/ssl/ssl_4D-removal_training.yaml (+10, −0)
- tests/test_pruning.py (+161, −66)