Commit c59d2d51 authored by Neta Zmora

Module name normalization: bug fixes, refactoring, tests

Fixed a bug in module name normalization for modules whose
name ends in ".module" (e.g. "features.module" in the case of
VGG).
Made the tests more robust, and refactored the common code
into distiller/utils.py.
parent 9e57219e
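
As a minimal illustration of the behavior this commit fixes, the sketch below (not part of the diff) shows the three name forms handled by normalize_module_name; it assumes the function is exported from the distiller package, as the new import in thinning.py below does. The assertions mirror the updated test added in this commit.

from distiller import normalize_module_name

# DataParallel wrapping the whole model prepends "module.":
assert normalize_module_name("module.features.0") == "features.0"
# DataParallel wrapping a sub-module injects "module." in the middle:
assert normalize_module_name("features.module.0") == "features.0"
# The case fixed here: a name that ends in ".module" (e.g. VGG's "features.module"):
assert normalize_module_name("features.module") == "features"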
@@ -35,6 +35,7 @@ from collections import namedtuple
import torch
from .policy import ScheduledTrainingPolicy
import distiller
from distiller import normalize_module_name, denormalize_module_name
from apputils import SummaryGraph
from models import ALL_MODEL_NAMES, create_model
msglogger = logging.getLogger()
@@ -69,6 +70,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
'find_nonzero_channels',
'execute_thinning_recipes_list']
def create_graph(dataset, arch):
if dataset == 'imagenet':
dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
@@ -81,23 +83,6 @@ def create_graph(dataset, arch):
return SummaryGraph(model, dummy_input)
def normalize_layer_name(layer_name):
start = layer_name.find('module.')
normalized_layer_name = layer_name
if start != -1:
normalized_layer_name = layer_name[:start] + layer_name[start + len('module.'):]
return normalized_layer_name
def denormalize_layer_name(model, normalized_name):
"""Convert back from the normalized form of the layer name, to PyTorch's name
which contains "artifacts" if DataParallel is used.
"""
ugly = [mod_name for mod_name, _ in model.named_modules() if normalize_layer_name(mod_name) == normalized_name]
assert len(ugly) == 1
return ugly[0]
def param_name_2_layer_name(param_name):
return param_name[:-len('weights')]
@@ -262,9 +247,9 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
append_param_directive(thinning_recipe, param_name, (1, indices))
# Find all instances of Convolution layers that immediately precede this layer
predecessors = sgraph.predecessors_f(normalize_layer_name(layer_name), ['Conv'])
predecessors = sgraph.predecessors_f(normalize_module_name(layer_name), ['Conv'])
# Convert the layer names to PyTorch's convoluted naming scheme (when DataParallel is used)
predecessors = [denormalize_layer_name(model, predecessor) for predecessor in predecessors]
predecessors = [denormalize_module_name(model, predecessor) for predecessor in predecessors]
for predecessor in predecessors:
# For each of the convolutional layers that precede it, we have to reduce the number of output channels.
append_module_directive(thinning_recipe, predecessor, key='out_channels', val=len(nonzero_channels))
@@ -273,11 +258,11 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
# Now handle the BatchNormalization layer that follows the convolution
bn_layers = sgraph.predecessors_f(normalize_layer_name(layer_name), ['BatchNormalization'])
bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization'])
if len(bn_layers) > 0:
assert len(bn_layers) == 1
# Thinning of the BN layer that follows the convolution
bn_layer_name = denormalize_layer_name(model, bn_layers[0])
bn_layer_name = denormalize_module_name(model, bn_layers[0])
bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_channels), thin_features=indices)
return thinning_recipe
@@ -331,9 +316,9 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
append_param_directive(thinning_recipe, layer_name+'.bias', (0, indices))
# Find all instances of Convolution or FC (GEMM) layers that immediately follow this layer
successors = sgraph.successors_f(normalize_layer_name(layer_name), ['Conv', 'Gemm'])
successors = sgraph.successors_f(normalize_module_name(layer_name), ['Conv', 'Gemm'])
# Convert the layer names to PyTorch's convoluted naming scheme (when DataParallel is used)
successors = [denormalize_layer_name(model, successor) for successor in successors]
successors = [denormalize_module_name(model, successor) for successor in successors]
for successor in successors:
if isinstance(layers[successor], torch.nn.modules.Conv2d):
@@ -360,11 +345,11 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
append_param_directive(thinning_recipe, successor+'.weight', (1, indices, view_4D, view_2D))
# Now handle the BatchNormalization layer that follows the convolution
bn_layers = sgraph.successors_f(normalize_layer_name(layer_name), ['BatchNormalization'])
bn_layers = sgraph.successors_f(normalize_module_name(layer_name), ['BatchNormalization'])
if len(bn_layers) > 0:
assert len(bn_layers) == 1
# Thinning of the BN layer that follows the convolution
bn_layer_name = denormalize_layer_name(model, bn_layers[0])
bn_layer_name = denormalize_module_name(model, bn_layers[0])
bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_filters), thin_features=indices)
return thinning_recipe
@@ -412,6 +397,7 @@ class FilterRemover(ScheduledTrainingPolicy):
# The epoch has ended and we reset the 'done' flag, so that the FilterRemover instance can be reused
self.done = False
def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
# Invoke this function when you want to use a list of thinning recipes to convert a programmed model
# to a thinned model. For example, this is invoked when loading a model from a checkpoint.
@@ -419,6 +405,7 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
msglogger.info("recipe %d:" % i)
execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=True)
def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=False):
"""Apply a thinning recipe to a model.
......
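
A hedged sketch of the normalize-then-denormalize pattern used by the create_thinning_recipe_* functions above. The construction calls mirror those in this repository (create_model, SummaryGraph, predecessors_f); the CIFAR dummy-input shape and the layer name are illustrative assumptions.

import torch
from models import create_model
from apputils import SummaryGraph
from distiller import normalize_module_name, denormalize_module_name

# The training model may be wrapped by DataParallel; the graph is built from a plain model.
model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=True)
graph_model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
sgraph = SummaryGraph(graph_model, torch.randn((1, 3, 32, 32), requires_grad=False))

layer_name = "module.layer1.0.conv2"  # illustrative DataParallel-style name
# SummaryGraph is keyed by normalized names, so normalize before querying it...
predecessors = sgraph.predecessors_f(normalize_module_name(layer_name), ['Conv'])
# ...and map the results back to the training model's own (possibly mangled) module names.
predecessors = [denormalize_module_name(model, pred) for pred in predecessors]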
@@ -24,14 +24,17 @@ import numpy as np
import torch
from torch.autograd import Variable
def to_np(var):
return var.data.cpu().numpy()
def to_var(tensor, cuda=True):
if cuda and torch.cuda.is_available():
tensor = tensor.cuda()
return Variable(tensor)
def size2str(torch_size):
if isinstance(torch_size, torch.Size):
return size_to_str(torch_size)
@@ -41,15 +44,45 @@ def size2str(torch_size):
return size_to_str(torch_size.data.size())
raise TypeError
def size_to_str(torch_size):
"""Convert a pytorch Size object to a string"""
assert isinstance(torch_size, torch.Size)
return '('+(', ').join(['%d' % v for v in torch_size])+')'
def normalize_module_name(layer_name):
"""Normalize a module's name.
PyTorch lets you parallelize the computation of a model by wrapping it with a
DataParallel module. Unfortunately, this changes the fully-qualified name of a module,
even though the actual functionality of the module doesn't change.
Many times, when we search for modules by name, we are indifferent to the DataParallel
wrapper and want to use the same module name whether the module is parallel or not.
We call this module name normalization, and this is implemented here.
"""
if layer_name.find("module.") >= 0:
return layer_name.replace("module.", "")
return layer_name.replace(".module", "")
def denormalize_module_name(parallel_model, normalized_name):
"""Convert back from the normalized form of the layer name, to PyTorch's name
which contains "artifacts" if DataParallel is used.
"""
fully_qualified_name = [mod_name for mod_name, _ in parallel_model.named_modules() if
normalize_module_name(mod_name) == normalized_name]
if len(fully_qualified_name) > 0:
return fully_qualified_name[-1]
else:
return "" # Did not find a module with the name <normalized_name>
def volume(tensor):
"""return the volume of a pytorch tensor"""
return np.prod(tensor.shape)
def density(tensor):
"""Computes the density of a tensor.
@@ -84,6 +117,7 @@ def sparsity(tensor):
"""
return 1.0 - density(tensor)
def sparsity_3D(tensor):
"""Filter-wise sparsity for 4D tensors"""
if tensor.dim() != 4:
@@ -93,10 +127,12 @@ def sparsity_3D(tensor):
nonzero_filters = len(torch.nonzero(view_3d.abs().sum(dim=1)))
return 1 - nonzero_filters/num_filters
def density_3D(tensor):
"""Filter-wise density for 4D tensors"""
return 1 - sparsity_3D(tensor)
def sparsity_2D(tensor):
"""Create a list of sparsity levels for each channel in the tensor 't'
@@ -128,10 +164,12 @@ def sparsity_2D(tensor):
nonzero_structs = len(torch.nonzero(view_2d.abs().sum(dim=1)))
return 1 - nonzero_structs/num_structs
def density_2D(tensor):
"""Kernel-wise sparsity for 4D tensors"""
return 1 - sparsity_2D(tensor)
def sparsity_ch(tensor):
"""Channel-wise sparsity for 4D tensors"""
if tensor.dim() != 4:
@@ -150,10 +188,12 @@ def sparsity_ch(tensor):
nonzero_channels = len(torch.nonzero(k_sums_mat.abs().sum(dim=1)))
return 1 - nonzero_channels/num_kernels_per_filter
def density_ch(tensor):
"""Channel-wise density for 4D tensors"""
return 1 - sparsity_ch(tensor)
def sparsity_cols(tensor):
"""Column-wise sparsity for 2D tensors"""
if tensor.dim() != 2:
@@ -163,10 +203,12 @@ def sparsity_cols(tensor):
nonzero_cols = len(torch.nonzero(tensor.abs().sum(dim=0)))
return 1 - nonzero_cols/num_cols
def density_cols(tensor):
"""Column-wise density for 2D tensors"""
return 1 - sparsity_cols(tensor)
def sparsity_rows(tensor):
"""Row-wise sparsity for 2D matrices"""
if tensor.dim() != 2:
@@ -176,10 +218,12 @@ def sparsity_rows(tensor):
nonzero_rows = len(torch.nonzero(tensor.abs().sum(dim=1)))
return 1 - nonzero_rows/num_rows
def density_rows(tensor):
"""Row-wise density for 2D tensors"""
return 1 - sparsity_rows(tensor)
def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total_steps, log_freq, loggers):
"""Log information about the training progress, and the distribution of the weight tensors.
......
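
A small, self-contained illustration of the structured-sparsity helpers in this file. It assumes they are importable from the distiller package, just as normalize_module_name is in the hunks above; the weight tensor is made up.

import torch
from distiller import sparsity, sparsity_3D, sparsity_ch

weights = torch.randn(8, 4, 3, 3)  # Conv weights: 8 filters, 4 input channels, 3x3 kernels
weights[0].zero_()                 # zero out one entire filter
weights[:, 1, :, :].zero_()        # zero out one entire channel (across all filters)

print(sparsity(weights))     # element-wise sparsity
print(sparsity_3D(weights))  # filter-wise sparsity: 1 of 8 filters is all-zero -> 0.125
print(sparsity_ch(weights))  # channel-wise sparsity: 1 of 4 channels is all-zero -> 0.25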
@@ -26,6 +26,7 @@ import distiller
import pytest
from models import ALL_MODEL_NAMES, create_model
from apputils import SummaryGraph, onnx_name_2_pytorch_name
from distiller import normalize_module_name, denormalize_module_name
# Logging configuration
logging.basicConfig(level=logging.DEBUG)
@@ -124,13 +125,6 @@ def test_layer_search():
assert preds == ['layer1.0.conv2', 'conv1']
def normalize_layer_name(layer_name):
start = layer_name.find('module.')
if start != -1:
layer_name = layer_name[:start] + layer_name[start + len('module.'):]
return layer_name
def test_vgg():
g = create_graph('imagenet', 'vgg19')
assert g is not None
@@ -139,10 +133,29 @@ def test_vgg():
succs = g.successors_f('features.34', 'Conv')
def test_normalize_layer_name():
assert "features.0", normalize_layer_name("features.module.0")
assert "features.0", normalize_layer_name("module.features.0")
assert "features.0", normalize_layer_name("features.0.module")
def name_test(dataset, arch):
model = create_model(False, dataset, arch, parallel=False)
modelp = create_model(False, dataset, arch, parallel=True)
assert model is not None and modelp is not None
mod_names = [mod_name for mod_name, _ in model.named_modules()]
mod_names_p = [mod_name for mod_name, _ in modelp.named_modules()]
assert mod_names is not None and mod_names_p is not None
assert len(mod_names)+1 == len(mod_names_p)
for i in range(len(mod_names)-1):
assert mod_names[i+1] == normalize_module_name(mod_names_p[i+2])
logging.debug("{} {} {}".format(mod_names_p[i+2], mod_names[i+1], normalize_module_name(mod_names_p[i+2])))
assert mod_names_p[i+2] == denormalize_module_name(modelp, mod_names[i+1])
def test_normalize_module_name():
assert "features.0" == normalize_module_name("features.module.0")
assert "features.0" == normalize_module_name("module.features.0")
assert "features" == normalize_module_name("features.module")
name_test('imagenet', 'vgg19')
name_test('cifar10', 'resnet20_cifar')
name_test('imagenet', 'alexnet')
def test_onnx_name_2_pytorch_name():
......
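
For reference, the DataParallel naming artifact that name_test() checks for can be reproduced with plain PyTorch; the toy model below is illustrative and not part of this repository.

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
parallel = nn.DataParallel(model)

print([name for name, _ in model.named_modules()])     # ['', '0', '1']
print([name for name, _ in parallel.named_modules()])  # ['', 'module', 'module.0', 'module.1']
# The wrapper contributes one extra entry ('module') and prefixes every other name with
# "module." -- hence the len(mod_names)+1 == len(mod_names_p) assertion and the i+1/i+2
# index offsets in name_test() above.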