From ef3e7415625f3aeb8300c4147b0dd551fd2decf7 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Mon, 19 Aug 2019 10:09:20 +0300 Subject: [PATCH] model_transforms.py: added copyright + other non-functional changes --- distiller/model_transforms.py | 78 +++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 30 deletions(-) diff --git a/distiller/model_transforms.py b/distiller/model_transforms.py index cb2802b..dca01f4 100644 --- a/distiller/model_transforms.py +++ b/distiller/model_transforms.py @@ -1,7 +1,22 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import torch import torch.nn as nn from collections import OrderedDict - import distiller import distiller.modules from distiller.quantization.sim_bn_fold import SimulatedFoldedBatchNorm @@ -9,34 +24,7 @@ import logging msglogger = logging.getLogger() -def _fuse_sequence(sequence, named_modules, fuse_fn): - names = [m.distiller_name for m in sequence] - msglogger.debug('Fusing sequence {}'.format(names)) - - # Call fusing function - fused_module = fuse_fn(sequence) - if fused_module is None: - msglogger.debug('Sequence {} was not fused'.format(names)) - return - - # Leave a 'mark' in the fused module, indicating which modules were fused. This can come in handy - # post-fusing, since the identity nodes don't show up in SummrayGraph (they're optimized away). - setattr(sequence[0], 'fused_modules', names[1:]) - - # Replace the first module in the sequence with the fused module - def split_name(name): - if '.' in name: - return name.rsplit('.', 1) - else: - return '', name - container_name, root_module = split_name(names[0]) - container = named_modules[container_name] - setattr(container, root_module, fused_module) - - # Replace the rest of the models in the sequence with identity ops - for container_name, sub_module_name in map(lambda name: split_name(name), names[1:]): - container = named_modules[container_name] - setattr(container, sub_module_name, nn.Identity()) +__all__ = ["fuse_modules", "fold_batch_norms_inference"] def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map=None): @@ -96,7 +84,7 @@ def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map _fuse_sequence(curr_sequence, named_modules, fuse_fn) reset = True elif len(adj_entry.successors) > 1: - msglogger.debug(node_name + " is connected to multiple outputs, not fuse-able") + msglogger.debug(node_name + " is connected to multiple outputs, not fusible") reset = True elif isinstance(module, types_sequence[0]): # Current module breaks the current sequence, check if it's the start of a new sequence @@ -139,3 +127,33 @@ def fold_batch_norms_inference(model, dummy_input=None, adjacency_map=None): foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d) batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map) + + +def _fuse_sequence(sequence, named_modules, fuse_fn): + names = [m.distiller_name for m in sequence] + msglogger.debug('Fusing sequence {}'.format(names)) + + # Call fusing function + fused_module = fuse_fn(sequence) + if fused_module is None: + msglogger.error('Sequence {} was not fused'.format(names)) + return + + # Leave a 'mark' in the fused module, indicating which modules were fused. This can come in handy + # post-fusion, since the identity nodes don't show up in SummaryGraph (they're optimized away). + setattr(sequence[0], 'fused_modules', names[1:]) + + # Replace the first module in the sequence with the fused module + def split_name(name): + if '.' in name: + return name.rsplit('.', 1) + else: + return '', name + container_name, root_module = split_name(names[0]) + container = named_modules[container_name] + setattr(container, root_module, fused_module) + + # Replace the rest of the modules in the sequence with identity ops + for container_name, sub_module_name in map(lambda name: split_name(name), names[1:]): + container = named_modules[container_name] + setattr(container, sub_module_name, nn.Identity()) -- GitLab