    Pruning with virtual Batch-norm statistics folding (#415) · c849a25f
    Neta Zmora authored
    * pruning: add an option to virtually fold BN into Conv2D for ranking
    
    PruningPolicy accepts a new control argument, `fold_batchnorm`. When set to `True`, the weights of BatchNorm modules are folded into the weights of Conv-2D modules (if Conv2D->BN edges exist in the model graph). Each weight filter is attenuated using a different pair of (gamma, beta) coefficients, so `fold_batchnorm` is relevant for fine-grained and filter-ranking pruning methods. We attenuate using the running values of the mean and variance, as is done in quantization.
    This control argument is only supported for Conv-2D modules (i.e. other convolution operation variants and Linear operations are not supported).
    e.g.:
    policies:
      - pruner:
          instance_name : low_pruner
          args:
            fold_batchnorm: True
        starting_epoch: 0
        ending_epoch: 30
        frequency: 2
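
The effect of virtual folding on filter ranking can be illustrated with a minimal NumPy sketch. It is not Distiller's implementation: the function names `fold_bn_into_conv_weights` and `rank_filters_l1` are hypothetical, and the sketch assumes the standard folding formula, in which each filter is scaled by gamma / sqrt(running_var + eps). Only the weight scaling matters for ranking, so the bias term (where beta and the running mean enter) is omitted. The fold is "virtual" because it works on a copy and leaves the real weights untouched.

```python
import numpy as np

def fold_bn_into_conv_weights(weights, gamma, running_var, eps=1e-5):
    """Return a BN-scaled copy of 4D conv weights shaped (out, in, kh, kw).

    Assumption: standard BN folding, one scale factor per output filter.
    The input array is not modified.
    """
    scale = gamma / np.sqrt(running_var + eps)
    return weights * scale.reshape(-1, 1, 1, 1)

def rank_filters_l1(weights):
    """Return one L1-norm score per output filter."""
    return np.abs(weights).reshape(weights.shape[0], -1).sum(axis=1)

# Two 1x3x3 filters; filter 1 has the larger raw magnitude.
weights = np.ones((2, 1, 3, 3))
weights[1] *= 2.0
original = weights.copy()

# Hypothetical BN statistics: filter 1 is strongly attenuated by its gamma.
gamma = np.array([1.0, 0.1])
running_var = np.array([1.0, 1.0])

raw_scores = rank_filters_l1(weights)
folded_scores = rank_filters_l1(
    fold_bn_into_conv_weights(weights, gamma, running_var))

# The weakest filter (lowest score) is the first pruning candidate.
weakest_raw = int(np.argmin(raw_scores))
weakest_folded = int(np.argmin(folded_scores))
```

In this toy case the raw ranking marks filter 0 as the weakest, while the folded ranking marks filter 1, because its small gamma shrinks its effective contribution.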
    
    * AGP: non-functional refactoring
    
    distiller/pruning/automated_gradual_pruner.py – change `prune_to_target_sparsity`
    to `_set_param_mask_by_sparsity_target`, which is a more appropriate function
    name as we don’t really prune in this function
    
    * Simplify GEMM weights input-channel ranking logic
    
    Ranking weight-matrices by input channels is similar to ranking 4D
    Conv weights by input channels, so there is no need for duplicate logic.
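
As a rough illustration of why the duplicate logic is unnecessary (a sketch, not the code in `ranked_structures_pruner.py`): a 2D weight matrix shaped (out, in) can be viewed as a 4D Conv weight shaped (out, in, 1, 1), after which a single input-channel ranking routine covers both cases. The function name `rank_input_channels_l1` and the choice of the L1 norm are assumptions made for the example.

```python
import numpy as np

def rank_input_channels_l1(weights):
    """Return one L1-norm score per input channel.

    Accepts 4D conv weights (out, in, kh, kw) or a 2D matrix (out, in),
    which is treated as a conv weight with a 1x1 kernel.
    """
    if weights.ndim == 2:
        weights = weights.reshape(weights.shape[0], weights.shape[1], 1, 1)
    # Sum over every axis except the input-channel axis (axis 1).
    return np.abs(weights).sum(axis=(0, 2, 3))

# A 2x3 fully-connected weight matrix: 2 outputs, 3 input channels (columns).
fc = np.array([[1.0, -2.0, 0.5],
               [3.0,  0.0, 0.5]])
fc_scores = rank_input_channels_l1(fc)

# The same values stored as a 4D conv weight give identical scores.
conv_scores = rank_input_channels_l1(fc.reshape(2, 3, 1, 1))
```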
    
    distiller/pruning/ranked_structures_pruner.py
    - change `prune_to_target_sparsity` to `_set_param_mask_by_sparsity_target`,
      which is a more appropriate function name as we don’t really prune in this
      function
    - remove the code handling ranking of matrix rows
    
    distiller/norms.py – remove `rank_cols`.
    
    distiller/thresholding.py – in `expand_binary_map` treat `channels` group_type
    the same as the `cols` group_type when dealing with 2D weights
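
The idea behind this change can be sketched as follows. This is not the body of `expand_binary_map`: the helper `expand_map_2d`, its signature, and the (out, in) weight layout are assumptions for illustration. For a 2D weight matrix, a per-channel binary map and a per-column binary map describe the same structure, so both group types expand to the same mask.

```python
import numpy as np

def expand_map_2d(binary_map, weights_shape, group_type):
    """Expand a per-column binary map into a full 2D mask (hypothetical helper).

    binary_map has one entry per input feature (column); 1 keeps the
    column and 0 prunes it. For 2D weights, 'channels' is handled
    exactly like 'cols'.
    """
    out_features, in_features = weights_shape
    if group_type not in ("cols", "channels"):
        raise ValueError(f"unsupported group_type: {group_type}")
    if binary_map.shape != (in_features,):
        raise ValueError("binary_map must have one entry per column")
    # Repeat the per-column map once for every row of the weight matrix.
    return np.tile(binary_map, (out_features, 1))

keep = np.array([1.0, 0.0, 1.0])  # prune the middle column
mask_cols = expand_map_2d(keep, (2, 3), "cols")
mask_channels = expand_map_2d(keep, (2, 3), "channels")
```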
    
    * AGP: add example of ranking filters with virtual BN-folding
    
    Also update resnet20 AGP examples