diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 2d2dab63b81f4bd1fe9b05e26bee75bf7d58d019..7c0b09d0e0d4aaa509f831e9eb17abc1316f567d 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -93,12 +93,13 @@ class SummaryGraph(object):
     Edge = collections.namedtuple('Edge', 'src dst')
 
     def __init__(self, model, dummy_input):
-        model = distiller.make_non_parallel_copy(model)
-        with torch.onnx.set_training(model, False):
+        self._src_model = model
+        model_clone = distiller.make_non_parallel_copy(model)
+        with torch.onnx.set_training(model_clone, False):
-            device = next(model.parameters()).device
+            device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
-            trace, _ = jit.get_trace_graph(model, dummy_input)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
@@ -152,7 +153,7 @@ class SummaryGraph(object):
         self.add_macs_attr()
         self.add_footprint_attr()
         self.add_arithmetic_intensity_attr()
-        del model
+        del model_clone
 
     def __create_op(self, onnx_node):
         op = {}
@@ -266,15 +267,13 @@ class SummaryGraph(object):
         return [op for op in self.ops.values() if attr in op['attrs'] and f(op)]
 
     def find_op(self, lost_op_name):
-        assert isinstance(lost_op_name, str)
-        return self.ops.get(lost_op_name, None)
+        return self.ops.get(distiller.normalize_module_name(lost_op_name), None)
 
     def find_param(self, data_name):
         return self.params.get(data_name, None)
 
     def predecessors(self, op, depth, done_list=None):
         """Returns a list of <op>'s predecessors"""
-
         if done_list is None:
             done_list = []
 
@@ -288,16 +287,18 @@ class SummaryGraph(object):
         done_list += preds
 
         if depth == 1:
-            return preds
+            ret = preds
         else:
             ret = []
             for predecessor in preds:
                 ret += self.predecessors(predecessor, depth-1, done_list)
-        return ret
+
+        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
 
     def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None):
         """Returns a list of <op>'s predecessors, if they match the <predecessors_types> criteria.
         """
+        node_name = distiller.normalize_module_name(node_name)
         node = self.find_op(node_name)
         node_is_an_op = True
         if node is None:
@@ -319,7 +320,7 @@ class SummaryGraph(object):
         # We check if we found the type of node we're looking for,
         # and that this is not the first node in our search.
         if node['type'] in predecessors_types and len(done_list) > 1:
-            return [node_name]
+            return [distiller.denormalize_module_name(self._src_model, node_name)]
 
         # This is an operation node
         preds = [edge.src for edge in self.edges if (edge.dst == node_name and
@@ -331,11 +332,11 @@ class SummaryGraph(object):
         ret = []
         for predecessor in preds:
             ret += self.predecessors_f(predecessor, predecessors_types, done_list, logging)
-        return ret
+
+        return [distiller.denormalize_module_name(self._src_model, node) for node in ret]
 
     def successors(self, node, depth, done_list=None):
         """Returns a list of <op>'s successors"""
-
        if done_list is None:
            done_list = []
 
@@ -351,12 +352,13 @@ class SummaryGraph(object):
         done_list += succs
 
         if depth == 1:
-            return succs
+            ret = succs
         else:
             ret = []
             for successor in succs:
                 ret += self.successors(successor, depth-1, done_list)
-        return ret
+
+        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
 
     def successors_f(self, node_name, successors_types, done_list=None, logging=None):
         """Returns a list of <op>'s successors, if they match the <successors_types> criteria.
@@ -367,7 +369,7 @@ class SummaryGraph(object):
         <node_name> and the returned list of successors are strings, because
         """
-
+        node_name = distiller.normalize_module_name(node_name)
         node = self.find_op(node_name)
         node_is_an_op = True
         if node is None:
@@ -389,7 +391,7 @@ class SummaryGraph(object):
         # We check if we found the type of node we're looking for,
         # and that this is not the first node in our search.
         if node['type'] in successors_types and len(done_list) > 1:
-            return [node_name]
+            return [distiller.denormalize_module_name(self._src_model, node_name)]
 
         # This is an operation node
         succs = [edge.dst for edge in self.edges if (edge.src == node_name and
@@ -401,4 +403,15 @@ class SummaryGraph(object):
         ret = []
         for successor in succs:
             ret += self.successors_f(successor, successors_types, done_list, logging)
-        return ret
+
+        return [distiller.denormalize_module_name(self._src_model, node) for node in ret]
+
+    def named_params_layers(self):
+        for param_name, param in self._src_model.named_parameters():
+            # remove the extension of param_name, and then normalize it
+            # to create a normalized layer name
+            normalized_layer_name = distiller.normalize_module_name(
+                '.'.join(param_name.split('.')[:-1]))
+            sgraph_layer_name = distiller.denormalize_module_name(
+                self._src_model, normalized_layer_name)
+            yield sgraph_layer_name, param_name, param
diff --git a/distiller/thinning.py b/distiller/thinning.py
index 78cd2879469b1c62f40094f2702f4e4e2e6ecedb..0ca0e2499ce749dc8e8e65a37689f0c5a4b3941a 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -31,8 +31,6 @@ from collections import namedtuple
 import torch
 from .policy import ScheduledTrainingPolicy
 import distiller
-from distiller import normalize_module_name, denormalize_module_name
-from distiller.models import create_model
 from .summary_graph import SummaryGraph
 
 msglogger = logging.getLogger(__name__)
@@ -64,34 +62,23 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'execute_thinning_recipes_list', 'get_normalized_recipe']
 
 
-def create_graph(dataset, arch):
+def create_graph(dataset, model):
+    dummy_input = None
     if dataset == 'imagenet':
         dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
     elif dataset == 'cifar10':
         dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
     assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
 
-    model = create_model(False, dataset, arch, parallel=False)
-    assert model is not None
     dummy_input = dummy_input.to(distiller.model_device(model))
     return SummaryGraph(model, dummy_input)
 
 
 def get_normalized_recipe(recipe):
-    new_recipe = ThinningRecipe(modules={normalize_module_name(k): v for k, v in recipe.modules.items()},
-                                parameters={normalize_module_name(k): v for k, v in recipe.parameters.items()})
-    return new_recipe
-
-
-def param_name_2_layer_name(param_name):
-    """Convert a weights tensor's name to the name of the layer using the tensor.
-
-    By convention, PyTorch modules name their weights parameters as self.weight
-    (see for example: torch.nn.modules.conv) which means that their fully-qualified
-    name when enumerating a model's parameters is the modules name followed by '.weight'.
-    We exploit this convention to convert a weights tensor name to the fully-qualified
-    module name."""
-    return param_name[:-len('.weight')]
+    return ThinningRecipe(
+        modules={distiller.normalize_module_name(k): v for k, v in recipe.modules.items()},
+        parameters={distiller.normalize_module_name(k): v for k, v in recipe.parameters.items()},
+    )
 
 
 def directives_equal(d1, d2):
@@ -120,9 +107,8 @@ def append_param_directive(thinning_recipe, param_name, directive):
     thinning_recipe.parameters[param_name] = param_directives
 
 
-def append_module_directive(model, thinning_recipe, module_name, key, val):
+def append_module_directive(thinning_recipe, module_name, key, val):
     msglogger.debug("\t[recipe] setting {}.{} = {}".format(module_name, key, val))
-    module_name = denormalize_module_name(model, module_name)
     mod_directive = thinning_recipe.modules.get(module_name, {})
     mod_directive[key] = val
     thinning_recipe.modules[module_name] = mod_directive
@@ -180,7 +166,7 @@ def resnet_cifar_remove_layers(model):
 
 
 def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = create_graph(dataset, arch)
+    sgraph = create_graph(dataset, model)
     thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict)
     apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
@@ -234,7 +220,7 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer):
 
 
 def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = create_graph(dataset, arch)
+    sgraph = create_graph(dataset, model)
     thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
     apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
@@ -256,7 +242,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
 
     # Traverse all of the model's parameters, search for zero-channels, and
     # create a thinning recipe that descibes the required changes to the model.
-    for param_name, param in model.named_parameters():
+    for layer_name, param_name, param in sgraph.named_params_layers():
         # We are only interested in 4D weights (of Convolution layers)
         if param.dim() != 4:
             continue
@@ -272,43 +258,35 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
 
         # We are removing channels, so update the number of incoming channels (IFMs)
         # in the convolutional layer
-        layer_name = param_name_2_layer_name(param_name)
         assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(model, thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels)
+        append_module_directive(thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels)
 
         # Select only the non-zero filters
         indices = nonzero_channels.data.squeeze()
         append_param_directive(thinning_recipe, param_name, (1, indices))
 
         # Find all instances of Convolution layers that immediately preceed this layer
-        predecessors = sgraph.predecessors_f(normalize_module_name(layer_name), ['Conv'])
-        # Convert the layers names to PyTorch's convoluted naming scheme (when DataParallel is used)
-        predecessors = [normalize_module_name(predecessor) for predecessor in predecessors]
+        predecessors = sgraph.predecessors_f(layer_name, ['Conv'])
         if len(predecessors) == 0:
-            msglogger.info("Could not find predecessors for name={} normal={} {}".format(
-                layer_name, normalize_module_name(layer_name), denormalize_module_name(model, layer_name)))
+            msglogger.info("Could not find predecessors for name={}".format(layer_name))
         for predecessor in predecessors:
             # For each of the convolutional layers that preceed, we have to reduce the number of output channels.
-            append_module_directive(model, thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels)
+            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels)
 
             # Now remove channels from the weights tensor of the predecessor conv
-            append_param_directive(thinning_recipe, denormalize_module_name(model, predecessor)+'.weight', (0, indices))
+            append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
 
-            if layers[denormalize_module_name(model, predecessor)].bias is not None:
+            if layers[predecessor].bias is not None:
                 # This convolution has bias coefficients
-                append_param_directive(thinning_recipe, denormalize_module_name(model, predecessor)+'.bias', (0, indices))
+                append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices))
 
         # Now handle the BatchNormalization layer that follows the convolution
-        bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization'])
-        if len(bn_layers) > 0:
-            # if len(bn_layers) != 1:
-            #     raise RuntimeError("{} should have exactly one BN predecessors, but has {}".format(layer_name, len(bn_layers)))
-            for bn_layer in bn_layers:
-                # Thinning of the BN layer that follows the convolution
-                bn_layer_name = denormalize_module_name(model, bn_layer)
-                msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer_name))
-                append_bn_thinning_directive(thinning_recipe, layers, bn_layer_name,
-                                             len_thin_features=num_nnz_channels, thin_features=indices)
+        bn_layers = sgraph.predecessors_f(layer_name, ['BatchNormalization'])
+        for bn_layer in bn_layers:
+            # Thinning of the BN layer that follows the convolution
+            msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer))
+            append_bn_thinning_directive(thinning_recipe, layers, bn_layer,
+                                         len_thin_features=num_nnz_channels, thin_features=indices)
 
     msglogger.debug(thinning_recipe)
     return thinning_recipe
@@ -329,7 +307,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
     layers = {mod_name: m for mod_name, m in model.named_modules()}
 
-    for param_name, param in model.named_parameters():
+    for layer_name, param_name, param in sgraph.named_params_layers():
         # We are only interested in 4D weights
         if param.dim() != 4:
             continue
@@ -343,7 +321,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             raise ValueError("Trying to set zero filters for parameter %s is not allowed" % param_name)
         # If there are non-zero filters in this tensor then continue to next tensor
         if num_filters <= num_nnz_filters:
-            msglogger.debug("Skipping {} shape={}".format(param_name_2_layer_name(param_name), param.shape))
+            msglogger.debug("Skipping {} shape={}".format(param_name, param.shape))
             continue
 
         msglogger.info("In tensor %s found %d/%d zero filters", param_name,
@@ -351,9 +329,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
 
         # We are removing filters, so update the number of outgoing channels (OFMs)
         # in the convolutional layer
-        layer_name = param_name_2_layer_name(param_name)
         assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(model, thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
+        append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
 
         # Select only the non-zero filters
         indices = nonzero_filters.data.squeeze()
@@ -364,24 +341,20 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             append_param_directive(thinning_recipe, layer_name+'.bias', (0, indices))
 
         # Find all instances of Convolution or FC (GEMM) layers that immediately follow this layer
-        msglogger.debug("{} => {}".format(layer_name, normalize_module_name(layer_name)))
-        successors = sgraph.successors_f(normalize_module_name(layer_name), ['Conv', 'Gemm'])
-        # Convert the layers names to PyTorch's convoluted naming scheme (when DataParallel is used)
-        successors = [denormalize_module_name(model, successor) for successor in successors]
+        successors = sgraph.successors_f(layer_name, ['Conv', 'Gemm'])
         for successor in successors:
-
             if isinstance(layers[successor], torch.nn.modules.Conv2d):
                 # For each of the convolutional layers that follow, we have to reduce the number of input channels.
-                append_module_directive(model, thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
+                append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
 
                 # Now remove channels from the weights tensor of the successor conv
-                append_param_directive(thinning_recipe, denormalize_module_name(model, successor)+'.weight', (1, indices))
+                append_param_directive(thinning_recipe, successor+'.weight', (1, indices))
 
             elif isinstance(layers[successor], torch.nn.modules.Linear):
                 # If a Linear (Fully-Connected) layer follows, we need to update it's in_features member
                 fm_size = layers[successor].in_features // layers[layer_name].out_channels
                 in_features = fm_size * num_nnz_filters
-                append_module_directive(model, thinning_recipe, successor, key='in_features', val=in_features)
+                append_module_directive(thinning_recipe, successor, key='in_features', val=in_features)
                 msglogger.debug("[recipe] Linear {}: fm_size = {} layers[{}].out_channels={}".format(
                     successor, in_features, layer_name, layers[layer_name].out_channels))
                 msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features))
@@ -391,18 +364,16 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
                 fm_height = fm_width = int(math.sqrt(fm_size))
                 view_4D = (layers[successor].out_features, layers[layer_name].out_channels, fm_height, fm_width)
                 view_2D = (layers[successor].out_features, in_features)
-                append_param_directive(thinning_recipe,
-                                       denormalize_module_name(model, successor)+'.weight',
+                append_param_directive(thinning_recipe, successor+'.weight',
                                        (1, indices, view_4D, view_2D))
 
         # Now handle the BatchNormalization layer that follows the convolution
-        bn_layers = sgraph.successors_f(normalize_module_name(layer_name), ['BatchNormalization'])
+        bn_layers = sgraph.successors_f(layer_name, ['BatchNormalization'])
         if len(bn_layers) > 0:
            assert len(bn_layers) == 1
            # Thinning of the BN layer that follows the convolution
-           bn_layer_name = denormalize_module_name(model, bn_layers[0])
-           append_bn_thinning_directive(thinning_recipe, layers, bn_layer_name,
-                                        len_thin_features=num_nnz_filters, thin_features=indices)
+           append_bn_thinning_directive(thinning_recipe, layers, bn_layers[0],
+                                        len_thin_features=num_nnz_filters, thin_features=indices)
 
     return thinning_recipe
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 0d378fa9f18df827cbde0af82ce785f6b6d71d7d..195a9825adba8ccc9cdfb5f0f65f471834170d72 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -161,6 +161,23 @@ def test_normalize_module_name():
     name_test('imagenet', 'alexnet')
 
 
+def named_params_layers_test_aux(dataset, arch, dataparallel: bool):
+    model = create_model(False, dataset, arch, parallel=dataparallel)
+    sgraph = SummaryGraph(model, get_input(dataset))
+    sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())
+    for layer_name in sgraph_layer_names:
+        assert sgraph.find_op(layer_name) is not None, \
+            '{} was not found in summary graph'.format(layer_name)
+
+
+def test_named_params_layers():
+    for dataParallelModel in (True, False):
+        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
+        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)
+
+
 def test_onnx_name_2_pytorch_name():
     assert "layer3.0.relu1" == onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1", 'Relu')
onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1", 'Relu') assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv')