From b40dff5ed005504ad1c5d9ada84ee8c4ae4ebb56 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 18 Oct 2018 20:01:15 +0300 Subject: [PATCH] Bug fix: exporting Alexnet and VGG models to ONNX ONNX export in PyTorch doesn't know how to handle DataParallel layers, so we need to make sure that we remove all instances of nn.DataParallel from the model before exporting it. The previous ONNX implementation forgot to deal with the case of DataParallel layers that do not wrap the entire model (as in VGG, where only the feature-extractor layers are data-parallel). --- apputils/model_summaries.py | 26 +++--- distiller/utils.py | 178 ++++++++++++++++++++++++++++++------ 2 files changed, 159 insertions(+), 45 deletions(-) diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index 3d0cedc..6c2cda3 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -606,30 +606,26 @@ def dataset_dummy_input(dataset): return dummy_input -def export_img_classifier_to_onnx(model, onnx_fname, dataset): +def export_img_classifier_to_onnx(model, onnx_fname, dataset, export_params=True, add_softmax=True): """Export a PyTorch image classifier to ONNX. """ dummy_input = dataset_dummy_input(dataset).to('cuda') + # Pytorch 0.4 doesn't support exporting modules wrapped in DataParallel + model = distiller.make_non_parallel_copy(model) with torch.onnx.set_training(model, False): - # Pytorch 0.4 doesn't support exporting modules wrapped in DataParallel - if isinstance(model, torch.nn.DataParallel): - model = model.module - - # Explicitly add a softmax layer, because it is needed for the ONNX inference phase. - # We make a copy of the model, since we are about to change it (adding softmax). - model = deepcopy(model) - model.original_forward = model.forward - softmax = torch.nn.Softmax(dim=1) - model.forward = lambda input: softmax(model.original_forward(input)) - - torch.onnx.export(model, dummy_input, onnx_fname, verbose=False, export_params=True) + if add_softmax: + # Explicitly add a softmax layer, because it is needed for the ONNX inference phase. + model.original_forward = model.forward + softmax = torch.nn.Softmax(dim=1) + model.forward = lambda input: softmax(model.original_forward(input)) + torch.onnx.export(model, dummy_input, onnx_fname, verbose=False, export_params=export_params) msglogger.info('Exported the model to ONNX format at %s' % os.path.realpath(onnx_fname)) - def data_node_has_parent(g, id): for edge in g.edges: - if edge.dst == id: return True + if edge.dst == id: + return True return False diff --git a/distiller/utils.py b/distiller/utils.py index 13d1f4a..a193dfc 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -50,6 +50,38 @@ def pretty_int(i): return "{:,}".format(i) +def assign_layer_fq_names(container, name=None): + """Assign human-readable names to the modules (layers). + + Sometimes we need to access modules by their names, and we'd like to use + fully-qualified names for convinience. + """ + is_leaf = True + for key, module in container._modules.items(): + is_leaf = False + assign_layer_fq_names(module, ".".join([name, key]) if name is not None else key) + if is_leaf: + container.distiller_name = name + + +def find_module_by_fq_name(model, fq_mod_name): + """Given a module's fully-qualified name, find the module in the provided model. + + A fully-qualified name is assigned to modules in function assign_layer_fq_names. + + Arguments: + model: the model to search + fq_mod_name: the module whose name we want to look up + + Returns: + The module or None, if the module was not found. + """ + for module in model.modules(): + if hasattr(module, 'distiller_name') and fq_mod_name == module.distiller_name: + return module + return None + + def normalize_module_name(layer_name): """Normalize a module's name. @@ -94,7 +126,6 @@ def density(tensor): Returns: density (float) """ - assert torch.numel(tensor) > 0 nonzero = torch.nonzero(tensor) if nonzero.dim() == 0: return 0.0 @@ -193,34 +224,60 @@ def density_ch(tensor): return 1 - sparsity_ch(tensor) -def sparsity_cols(tensor): - """Column-wise sparsity for 2D tensors""" +def sparsity_matrix(tensor, dim): + """Generic sparsity computation for 2D matrices""" if tensor.dim() != 2: return 0 - num_cols = tensor.size()[1] - nonzero_cols = len(torch.nonzero(tensor.abs().sum(dim=0))) - return 1 - nonzero_cols/num_cols + num_structs = tensor.size()[dim] + nonzero_structs = len(torch.nonzero(tensor.abs().sum(dim=1-dim))) + return 1 - nonzero_structs/num_structs -def density_cols(tensor): +def sparsity_cols(tensor, trasposed=True): + """Column-wise sparsity for 2D tensors + + PyTorch GEMM matrices are transposed before they are used in the GEMM operation. + In other words the matrices are stored in memory transposed. So by default we compute + the sparsity of the transposed dimension. + """ + if trasposed: + return sparsity_matrix(tensor, 0) + return sparsity_matrix(tensor, 1) + + +def density_cols(tensor, transposed=True): """Column-wise density for 2D tensors""" - return 1 - sparsity_cols(tensor) + return 1 - sparsity_cols(tensor, transposed) -def sparsity_rows(tensor): - """Row-wise sparsity for 2D matrices""" - if tensor.dim() != 2: - return 0 +def sparsity_rows(tensor, trasposed=True): + """Row-wise sparsity for 2D matrices - num_rows = tensor.size()[0] - nonzero_rows = len(torch.nonzero(tensor.abs().sum(dim=1))) - return 1 - nonzero_rows/num_rows + PyTorch GEMM matrices are transposed before they are used in the GEMM operation. + In other words the matrices are stored in memory transposed. So by default we compute + the sparsity of the transposed dimension. + """ + if trasposed: + return sparsity_matrix(tensor, 1) + return sparsity_matrix(tensor, 0) -def density_rows(tensor): +def density_rows(tensor, transposed=True): """Row-wise density for 2D tensors""" - return 1 - sparsity_rows(tensor) + return 1 - sparsity_rows(tensor, transposed) + + +def norm_filters(weights, p=1): + """Compute the p-norm of convolution filters. + + Args: + weights - a 4D convolution weights tensor. + Has shape = (#filters, #channels, k_w, k_h) + p - the exponent value in the norm formulation + """ + assert weights.dim() == 4 + return weights.view(weights.size(0), -1).norm(p=p, dim=1) def model_numel(model, param_dims=[2, 4]): @@ -233,6 +290,70 @@ def model_numel(model, param_dims=[2, 4]): return total_numel +def activation_channels_l1(activation): + """Calculate the L1-norms of an activation's channels. + + The activation usually has the shape: (batch_size, num_channels, h, w). + + When the activations are computed on a distributed GPU system, different parts of the + activation tensor might be computed by a differnt GPU. If this function is called from + the forward-callback of some activation module in the graph, we will only witness part + of the batch. For example, if the batch_size is 256, and we are using 4 GPUS, instead + of seeing activations with shape = (256, num_channels, h, w), we may see 4 calls with + shape = (64, num_channels, h, w). + + Since we want to calculate the average of the L1-norm of each of the channels of the + activation, we need to move the partial sums results to the CPU, where they will be + added together. + + Returns - for each channel: the batch-mean of its L1 magnitudes (i.e. over all of the + activations in the mini-batch, compute the mean of the L! magnitude of each channel). + """ + view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channel) x (h*w) + featuremap_norms = view_2d.norm(p=1, dim=1) + featuremap_norms_mat = featuremap_norms.view(activation.size(0), activation.size(1)) # batch x channel + # We need to move the results back to the CPU + return featuremap_norms_mat.mean(dim=0).cpu() + + +def activation_channels_means(activation): + """Calculate the mean of each of an activation's channels. + + The activation usually has the shape: (batch_size, num_channels, h, w). + + "We first use global average pooling to convert the output of layer i, which is a + c x h x w tensor, into a 1 x c vector." + + Returns - for each channel: the batch-mean of its L1 magnitudes (i.e. over all of the + activations in the mini-batch, compute the mean of the L1 magnitude of each channel). + """ + view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channel) x (h*w) + featuremap_means = sparsity_rows(view_2d) + featuremap_means_mat = featuremap_means.view(activation.size(0), activation.size(1)) # batch x channel + # We need to move the results back to the CPU + return featuremap_means_mat.mean(dim=0).cpu() + + +def activation_channels_apoz(activation): + """Calculate the APoZ of each of an activation's channels. + + APoZ is the Average Percentage of Zeros (or simply: average sparsity) and is defined in: + "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures". + + The activation usually has the shape: (batch_size, num_channels, h, w). + + "We first use global average pooling to convert the output of layer i, which is a + c x h x w tensor, into a 1 x c vector." + + Returns - for each channel: the batch-mean of its sparsity. + """ + view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channel) x (h*w) + featuremap_means = view_2d.mean(dim=1) # global average pooling + featuremap_means_mat = featuremap_means.view(activation.size(0), activation.size(1)) # batch x channel + # We need to move the results back to the CPU + return featuremap_means_mat.mean(dim=0).cpu() + + def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total_steps, log_freq, loggers): """Log information about the training progress, and the distribution of the weight tensors. @@ -258,10 +379,12 @@ def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total logger.log_weights_distribution(params_dict, steps_completed) -def log_activation_sparsity(epoch, loggers, collector): +def log_activation_statsitics(epoch, phase, loggers, collector): """Log information about the sparsity of the activations""" + if collector is None: + return for logger in loggers: - logger.log_activation_sparsity(collector.value(), epoch) + logger.log_activation_statsitic(phase, collector.stat_name, collector.value(), epoch) def log_weights_sparsity(model, epoch, loggers): @@ -298,24 +421,19 @@ class DoNothingModuleWrapper(nn.Module): def make_non_parallel_copy(model): """Make a non-data-parallel copy of the provided model. - nn.DataParallel instances are replaced by DoNothingModuleWrapper - instances. + torch.nn.DataParallel instances are removed. """ - def replace_data_parallel(container, prefix=''): + def replace_data_parallel(container): for name, module in container.named_children(): - full_name = prefix + name if isinstance(module, nn.DataParallel): - # msglogger.debug('Replacing module {}'.format(full_name)) - setattr(container, name, DoNothingModuleWrapper(module.module)) + setattr(container, name, module.module) if has_children(module): - # For a container we call recursively - replace_data_parallel(module, full_name + '.') + replace_data_parallel(module) # Make a copy of the model, because we're going to change it new_model = deepcopy(model) if isinstance(new_model, nn.DataParallel): - # new_model = new_model.module # - new_model = DoNothingModuleWrapper(new_model.module) - + new_model = new_model.module replace_data_parallel(new_model) + return new_model -- GitLab