From ef3e7415625f3aeb8300c4147b0dd551fd2decf7 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 19 Aug 2019 10:09:20 +0300
Subject: [PATCH] model_transforms.py: added copyright + other non-functional
 changes

---
 distiller/model_transforms.py | 78 +++++++++++++++++++++--------------
 1 file changed, 48 insertions(+), 30 deletions(-)

diff --git a/distiller/model_transforms.py b/distiller/model_transforms.py
index cb2802b..dca01f4 100644
--- a/distiller/model_transforms.py
+++ b/distiller/model_transforms.py
@@ -1,7 +1,22 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 import torch
 import torch.nn as nn
 from collections import OrderedDict
-
 import distiller
 import distiller.modules
 from distiller.quantization.sim_bn_fold import SimulatedFoldedBatchNorm
@@ -9,34 +24,7 @@ import logging
 msglogger = logging.getLogger()
 
 
-def _fuse_sequence(sequence, named_modules, fuse_fn):
-    names = [m.distiller_name for m in sequence]
-    msglogger.debug('Fusing sequence {}'.format(names))
-
-    # Call fusing function
-    fused_module = fuse_fn(sequence)
-    if fused_module is None:
-        msglogger.debug('Sequence {} was not fused'.format(names))
-        return
-
-    # Leave a 'mark' in the fused module, indicating which modules were fused. This can come in handy
-    # post-fusing, since the identity nodes don't show up in SummrayGraph (they're optimized away).
-    setattr(sequence[0], 'fused_modules', names[1:])
-
-    # Replace the first module in the sequence with the fused module
-    def split_name(name):
-        if '.' in name:
-            return name.rsplit('.', 1)
-        else:
-            return '', name
-    container_name, root_module = split_name(names[0])
-    container = named_modules[container_name]
-    setattr(container, root_module, fused_module)
-
-    # Replace the rest of the models in the sequence with identity ops
-    for container_name, sub_module_name in map(lambda name: split_name(name), names[1:]):
-        container = named_modules[container_name]
-        setattr(container, sub_module_name, nn.Identity())
+__all__ = ["fuse_modules", "fold_batch_norms_inference"]
 
 
 def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map=None):
@@ -96,7 +84,7 @@ def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map
                     _fuse_sequence(curr_sequence, named_modules, fuse_fn)
                     reset = True
                 elif len(adj_entry.successors) > 1:
-                    msglogger.debug(node_name + " is connected to multiple outputs, not fuse-able")
+                    msglogger.debug(node_name + " is connected to multiple outputs, not fusible")
                     reset = True
             elif isinstance(module, types_sequence[0]):
                 # Current module breaks the current sequence, check if it's the start of a new sequence
@@ -139,3 +127,33 @@ def fold_batch_norms_inference(model, dummy_input=None, adjacency_map=None):
     foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
     batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
     return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map)
+
+
+def _fuse_sequence(sequence, named_modules, fuse_fn):
+    names = [m.distiller_name for m in sequence]
+    msglogger.debug('Fusing sequence {}'.format(names))
+
+    # Call fusing function
+    fused_module = fuse_fn(sequence)
+    if fused_module is None:
+        msglogger.error('Sequence {} was not fused'.format(names))
+        return
+
+    # Leave a 'mark' in the fused module, indicating which modules were fused. This can come in handy
+    # post-fusion, since the identity nodes don't show up in SummaryGraph (they're optimized away).
+    setattr(sequence[0], 'fused_modules', names[1:])
+
+    # Replace the first module in the sequence with the fused module
+    def split_name(name):
+        if '.' in name:
+            return name.rsplit('.', 1)
+        else:
+            return '', name
+    container_name, root_module = split_name(names[0])
+    container = named_modules[container_name]
+    setattr(container, root_module, fused_module)
+
+    # Replace the rest of the modules in the sequence with identity ops
+    for container_name, sub_module_name in map(lambda name: split_name(name), names[1:]):
+        container = named_modules[container_name]
+        setattr(container, sub_module_name, nn.Identity())
-- 
GitLab