diff --git a/distiller/modules/rnn.py b/distiller/modules/rnn.py
index 7e14f4e1b95ba454ea4c0c2fd01c132ff6eda976..622a47073f747ecd5fcd84589312ca5706ce09ed 100644
--- a/distiller/modules/rnn.py
+++ b/distiller/modules/rnn.py
@@ -201,8 +201,8 @@ class DistillerLSTM(nn.Module):
                 # # Process each timestep at the entire layers chain -
                 # # each timestep is forwarded through `front` and `back` chains independently,
                 # # similarily to a unidirectional LSTM.
-                # self.cells = self._create_cells_list('cell', 1)
-                # self.cells_reverse = self._create_cells_list('cell_reverse', 2)
+                # self.cells = self._create_cells_list(1)
+                # self.cells_reverse = self._create_cells_list(2)
                 # self.forward_fn = self.process_layer_wise
                 # self.layer_chain_fn = self._layer_chain_bidirectional_type1
@@ -210,36 +210,25 @@
                 # Process the entire sequence at each layer consecutively -
                 # the output of one layer is the sequence processed through the `front` and `back` cells
                 # and the input to the next layers are both `output_front` and `output_back`.
-                self.cells = self._create_cells_list('cell', 2)
-                self.cells_reverse = self._create_cells_list('cell_reverse', 2)
+                self.cells = self._create_cells_list(2)
+                self.cells_reverse = self._create_cells_list(2)
                 self.forward_fn = self._bidirectional_type2_forward
             else:
                 raise ValueError("The only allowed types are [1, 2].")
         else:
-            self.cells = self._create_cells_list('cell')
+            self.cells = self._create_cells_list()
             self.forward_fn = self.process_layer_wise
             self.layer_chain_fn = self._layer_chain_unidirectional
 
         self.dropout = nn.Dropout(dropout)
         self.dropout_factor = dropout
 
-    def _create_cells_list(self, name, hidden_size_scale=1):
-        # We don't use a ModuleList, because they don't show up properly as scope names when creating a trace.
-        # That makes it impossible to map back from the trace to the actual module, which in turn means that
-        # mechanisms that rely on understanding modules connectivity won't work (such as fusions in post-training
-        # quantization).
-        #
-        # So, we register each cell manually and just store them in a vanilla list
-
+    def _create_cells_list(self, hidden_size_scale=1):
         # We always have the first layer
-        c = DistillerLSTMCell(self.input_size, self.hidden_size, self.bias)
-        setattr(self, name + '_0', c)
-        cells = [c]
+        cells = nn.ModuleList([DistillerLSTMCell(self.input_size, self.hidden_size, self.bias)])
         for i in range(1, self.num_layers):
-            c = DistillerLSTMCell(hidden_size_scale * self.hidden_size, self.hidden_size, self.bias)
-            setattr(self, '{}_{}'.format(name, i), c)
-            cells.append(c)
+            cells.append(DistillerLSTMCell(hidden_size_scale * self.hidden_size, self.hidden_size, self.bias))
         return cells
 
     def forward(self, x, h=None):
diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index bea53e5f901b8bedf6b6c05f05bedbe038850108..7fb5ab5b712b82426df714f46a00d6ca9bfaa96d 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -19,9 +19,11 @@ import re
 import numpy as np
 import collections
 import torch
+import torch.nn as nn
 import torch.jit as jit
 import logging
 from collections import OrderedDict, defaultdict
+from collections.abc import MutableSequence, Iterable
 
 msglogger = logging.getLogger()
 
@@ -71,6 +73,11 @@ class SummaryGraph(object):
     def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):
         self._src_model = model
         model_clone = distiller.make_non_parallel_copy(model)
+
+        # Switch all instances of torch.nn.ModuleList in the model to our DistillerModuleList
+        # See documentation of _DistillerModuleList class for details on why this is done
+        model_clone, converted_module_names_map = _to_distiller_modulelist(model_clone)
+
         with torch.onnx.set_training(model_clone, False):
 
             device = distiller.model_device(model_clone)
@@ -142,6 +149,10 @@ class SummaryGraph(object):
 
                 # Convert the graph node's scope name to a PyTorch module name
                 module_name = onnx_name_2_pytorch_name(new_op['orig-name'])
+
+                # Get the module name from before the conversion to DistillerModuleList
+                module_name = converted_module_names_map[module_name]
+
                 if len(module_name) == 0:
                     # Special case where the module name is an empty string - this happens
                     # when the op is called from the "top-level" of the model
@@ -561,3 +572,169 @@ class AdjacentsEntry(object):
         return self.op_meta == other.op_meta and \
                self.predecessors == other.predecessors and \
                self.successors == other.successors
+
+
+class _DistillerModuleList(object):
+    r"""An almost-drop-in replacement for torch.nn.ModuleList that results in full and unique scope-names when traced
+
+    So why do we need this?
+    Some flows in Distiller, such as modules fusion and "net-aware" quantization in PostTrainLinearQuantizer, rely
+    on the ability to infer the connectivity within the model, at the Python API level. This is done using
+    SummaryGraph, which internally uses PyTorch's trace capabilities. When tracing, each operation executed
+    creates a node in the trace, which has a "scope-name". Distiller then uses the "scope-name" to do a reverse
+    mapping - map from the trace node back to the actual nn.Module defined in the model code.
+
+    These "scope-names" are generated by tracking the ".forward()" calls of modules. However, the
+    torch.nn.ModuleList class itself doesn't have a forward method. That makes perfect sense - it is only
+    intended to be used as a container of modules which the user accesses explicitly.
+    Unfortunately, this means that if an operation is part of a ModuleList, the name of the ModuleList instance
+    does not appear in the "scope-name".
+    This makes it impossible for us to do the reverse mapping mentioned above.
+
+    From here on, we refer to the module which contains the DistillerModuleList instance as the "parent module".
+
+    Similarities to torch.nn.ModuleList:
+      * A DistillerModuleList can be indexed like a regular Python list, but the modules it contains are properly
+        registered and will be visible to all torch.nn.Module methods.
+      * The DistillerModuleList instance is registered as an attribute of the "parent module"
+      * This means that in terms of accessing the modules and invoking them, DistillerModuleList behaves exactly the
+        same as torch.nn.ModuleList. See the example below.
+
+    Differences vs. torch.nn.ModuleList:
+      * DistillerModuleList is NOT a sub-class of torch.nn.Module
+      * This means that the modules in the list are NOT sub-modules of the list itself. They are registered as
+        sub-modules of the "parent module". That is - the contents of a DistillerModuleList are "flattened" within
+        the "parent module".
+      * In addition, we can't use the '.' character to denote the "nesting" of a module within the list. We use '_'.
+      * All of this means that calls to functions like state_dict() / named_modules() / named_children() / etc. on
+        the "parent_module" return different results when this class is used compared to torch.nn.ModuleList.
+
+    At the moment we don't see a usage for this class "in the wild", outside of SummaryGraph generation.
+    In the context of SummaryGraph, we're going to take a pre-created model and replace any torch.nn.ModuleList
+    instances with DistillerModuleLists. Once that is done, during model execution we expect the lists to be used
+    as read-only (no modules are added to / removed from the list). We don't support loading a state_dict "across"
+    converted models.
+    This means that:
+      * We implement only a subset of the standard API of a Python sequence (see collections.abc.MutableSequence):
+        'append()', 'extend()', '__len__()' and '__getitem__()'.
+        These are the only ones required to perform the conversion for an already created model.
+      * We're not implementing:
+        'insert()', '__setitem__()' and '__delitem__()'.
+
+    If we see in the future that our assumptions break, we'll add the necessary APIs.
+
+    For all the reasons mentioned above, and to avoid unnecessary confusion for users, we're keeping this class
+    internal to summary_graph for now.
+
+    Args:
+        name (string): The base name to be used when registering modules added to the list
+        parent_module (torch.nn.Module): The module to which the modules added to the list will be registered.
+          NOTE: This is expected to be the module containing the list, but we can't enforce this.
+        modules (iterable, optional): An iterable of modules to initialize the list with
+    """
+    def __init__(self, name, parent_module, modules=None):
+        self.name = name
+        if not isinstance(parent_module, nn.Module):
+            raise TypeError('parent_module must be an instance of torch.nn.Module')
+        self.parent_module = parent_module
+        self._modules = []
+        if modules is not None:
+            self.extend(modules)
+
+    def _name_for_idx(self, idx):
+        return self.name + '_' + str(idx)
+
+    def _verify_on_insertion(self, module, idx):
+        if isinstance(module, nn.ModuleList):
+            module = _DistillerModuleList(self._name_for_idx(idx), self.parent_module, module)
+        if isinstance(module, _DistillerModuleList):
+            if module.parent_module != self.parent_module:
+                raise ValueError("When nesting one DistillerModuleList within another, both must have the same "
+                                 "'parent_module'")
+        return module
+
+    def __getitem__(self, idx):
+        return self._modules[idx]
+
+    def __len__(self):
+        return len(self._modules)
+
+    def append(self, module):
+        module = self._verify_on_insertion(module, len(self))
+        if not isinstance(module, _DistillerModuleList):
+            self.parent_module.add_module(self._name_for_idx(len(self)), module)
+        self._modules.append(module)
+
+    def extend(self, modules):
+        if not isinstance(modules, Iterable):
+            raise TypeError('DistillerModuleList.extend must be called with an iterable, but got ' +
+                            modules.__class__.__name__)
+        for module in modules:
+            self.append(module)
+
+    def named_modules(self, memo=None, prefix=''):
+        if memo is None:
+            memo = set()
+        if self not in memo:
+            memo.add(self)
+            # yield prefix, self
+            for idx, module in enumerate(self._modules):
+                if module is None:
+                    continue
+                submodule_prefix = prefix + ('.' if prefix else '') + str(idx)
+                for m in module.named_modules(memo, submodule_prefix):
+                    yield m
+
+    def modules(self):
+        for _, module in self.named_modules():
+            yield module
+
+    def __repr__(self):
+        # A simplified version of torch.nn.Module.__repr__
+        from torch.nn.modules.module import _addindent
+
+        child_lines = []
+        for idx, module in enumerate(self._modules):
+            mod_str = repr(module)
+            mod_str = _addindent(mod_str, 2)
+            child_lines.append('(' + str(idx) + '): ' + mod_str)
+
+        main_str = self.__class__.__name__ + '('
+        if child_lines:
+            main_str += '\n  ' + '\n  '.join(child_lines) + '\n'
+        main_str += ')'
+        return main_str
+
+
+def _to_distiller_modulelist(model):
+    """Replaces all instances of torch.nn.ModuleList in a model with DistillerModuleList instances
+
+    Args:
+        model (torch.nn.Module): Model to convert
+    """
+    def convert_container(container):
+        named_children = OrderedDict(container.named_children())
+        # To maintain a similar order of registered modules compared to the original container, we unregister
+        # all modules and then register them again
+        for n, _ in named_children.items():
+            delattr(container, n)
+        for name, child in named_children.items():
+            if isinstance(child, nn.ModuleList):
+                child = _DistillerModuleList(name, container, child)
+                to_check = child.modules()
+            else:
+                to_check = [child]
+            setattr(container, name, child)
+            for m in to_check:
+                if isinstance(m, _DistillerModuleList):
+                    continue
+                if distiller.has_children(m):
+                    convert_container(m)
+        return container
+
+    named_modules_orig = OrderedDict([(n, m) for n, m in model.named_modules() if not isinstance(m, nn.ModuleList)])
+    model = convert_container(model)
+    named_modules_dmlist = OrderedDict(model.named_modules())
+    converted_module_names_map = OrderedDict(zip(named_modules_dmlist.keys(),
+                                                 named_modules_orig.keys()))
+
+    return model, converted_module_names_map
diff --git a/examples/word_language_model/manual_lstm_pretrained_stats.yaml b/examples/word_language_model/manual_lstm_pretrained_stats.yaml
index 086d909a29a7db963d36f1dcd3547d84587d187e..34baf1601ad82b4d40717d3a4f7dbb05a30ecee7 100644
--- a/examples/word_language_model/manual_lstm_pretrained_stats.yaml
+++ b/examples/word_language_model/manual_lstm_pretrained_stats.yaml
@@ -16,7 +16,7 @@ encoder:
     mean: 0.0011389567903235516
     std: 0.106637374597835
     shape: (35, 10, 1500)
-rnn.cell_0.fc_gate_x:
+rnn.cells.0.fc_gate_x:
   inputs:
     0:
       min: -1.0055707693099976
@@ -34,7 +34,7 @@ rnn.cell_0.fc_gate_x:
     mean: -0.5324024169690869
     std: 1.136934240306631
     shape: (10, 6000)
-rnn.cell_0.fc_gate_h:
+rnn.cells.0.fc_gate_h:
   inputs:
     0:
       min: -0.9941717386245728
@@ -52,7 +52,7 @@ rnn.cell_0.fc_gate_h:
     mean: -0.24742181252201967
     std: 0.6138123563333803
     shape: (10, 6000)
-rnn.cell_0.eltwiseadd_gate:
+rnn.cells.0.eltwiseadd_gate:
   inputs:
     0:
      min: -7.170577049255371
@@ -78,7 +78,7 @@ rnn.cell_0.eltwiseadd_gate:
     mean: -0.7798242293363614
     std: 1.3385719958875721
     shape: (10, 6000)
-rnn.cell_0.act_f:
+rnn.cells.0.act_f:
   inputs:
     0:
       min: -15.612003326416016
@@ -96,7 +96,7 @@ rnn.cell_0.act_f:
     mean: 0.31776015875831237
     std: 0.18128372322608863
     shape: (10, 1500)
-rnn.cell_0.act_i:
+rnn.cells.0.act_i:
   inputs:
     0:
       min: -12.639559745788574
@@ -114,7 +114,7 @@ rnn.cell_0.act_i:
     mean: 0.26847494655177206
     std: 0.2133546549199177
     shape: (10, 1500)
-rnn.cell_0.act_o:
+rnn.cells.0.act_o:
   inputs:
     0:
       min: -9.940855979919434
@@ -132,7 +132,7 @@ rnn.cell_0.act_o:
     mean: 0.318497704179476
     std: 0.19847815427571672
     shape: (10, 1500)
-rnn.cell_0.act_g:
+rnn.cells.0.act_g:
   inputs:
     0:
       min: -11.641252517700195
@@ -150,7 +150,7 @@ rnn.cell_0.act_g:
     mean: -6.53550379788874e-05
     std: 0.7375218759945332
     shape: (10, 1500)
-rnn.cell_0.eltwisemult_cell_forget:
+rnn.cells.0.eltwisemult_cell_forget:
   inputs:
     0:
       min: 1.6587961226832704e-07
@@ -176,7 +176,7 @@ rnn.cell_0.eltwisemult_cell_forget:
     mean: -3.0967179330471856e-05
     std: 0.13354148372150612
     shape: (10, 1500)
-rnn.cell_0.eltwisemult_cell_input:
+rnn.cells.0.eltwisemult_cell_input:
   inputs:
     0:
       min: 3.2412126529379748e-06
@@ -202,7 +202,7 @@ rnn.cell_0.eltwisemult_cell_input:
     mean: -0.0009125015388671264
     std: 0.2702341313401001
     shape: (10, 1500)
-rnn.cell_0.eltwiseadd_cell:
+rnn.cells.0.eltwiseadd_cell:
   inputs:
     0:
       min: -18.10362434387207
@@ -228,7 +228,7 @@ rnn.cell_0.eltwiseadd_cell:
     mean: -0.0009434687136402496
     std: 0.3030323653891493
     shape: (10, 1500)
-rnn.cell_0.act_h:
+rnn.cells.0.act_h:
   inputs:
     0:
       min: -18.4894962310791
@@ -246,7 +246,7 @@ rnn.cell_0.act_h:
     mean: -0.0007417835671436678
     std: 0.26512230259603614
     shape: (10, 1500)
-rnn.cell_0.eltwisemult_hidden:
+rnn.cells.0.eltwisemult_hidden:
   inputs:
     0:
       min: 4.816373984795064e-05
@@ -272,7 +272,7 @@ rnn.cell_0.eltwisemult_hidden:
     mean: 0.00019907324964651022
     std: 0.10292063063382816
     shape: (10, 1500)
-rnn.cell_1.fc_gate_x:
+rnn.cells.1.fc_gate_x:
   inputs:
     0:
       min: -0.9941717386245728
@@ -290,7 +290,7 @@ rnn.cell_1.fc_gate_x:
     mean: -0.2753187046971623
     std: 1.2196254520335272
     shape: (10, 6000)
-rnn.cell_1.fc_gate_h:
+rnn.cells.1.fc_gate_h:
   inputs:
     0:
       min: -0.9998947978019714
@@ -308,7 +308,7 @@ rnn.cell_1.fc_gate_h:
     mean: -0.2947964010912462
     std: 1.2404614338320914
     shape: (10, 6000)
-rnn.cell_1.eltwiseadd_gate:
+rnn.cells.1.eltwiseadd_gate:
   inputs:
     0:
       min: -14.241074562072754
@@ -334,7 +334,7 @@ rnn.cell_1.eltwiseadd_gate:
     mean: -0.5701151059095978
     std: 1.9359252436683256
     shape: (10, 6000)
-rnn.cell_1.act_f:
+rnn.cells.1.act_f:
   inputs:
     0:
       min: -9.876046180725098
@@ -352,7 +352,7 @@ rnn.cell_1.act_f:
     mean: 0.41393780546997105
     std: 0.25393672173390425
     shape: (10, 1500)
-rnn.cell_1.act_i:
+rnn.cells.1.act_i:
   inputs:
     0:
       min: -16.290634155273438
@@ -370,7 +370,7 @@ rnn.cell_1.act_i:
     mean: 0.2965666530026901
     std: 0.2653741353156098
     shape: (10, 1500)
-rnn.cell_1.act_o:
+rnn.cells.1.act_o:
   inputs:
     0:
       min: -16.701549530029297
@@ -388,7 +388,7 @@ rnn.cell_1.act_o:
     mean: 0.3911302530009079
     std: 0.30228075066032956
     shape: (10, 1500)
-rnn.cell_1.act_g:
+rnn.cells.1.act_g:
   inputs:
     0:
       min: -18.20545196533203
@@ -406,7 +406,7 @@ rnn.cell_1.act_g:
     mean: 0.006299477315604894
     std: 0.8122228369780807
     shape: (10, 1500)
-rnn.cell_1.eltwisemult_cell_forget:
+rnn.cells.1.eltwisemult_cell_forget:
   inputs:
     0:
       min: 5.1388426072662696e-05
@@ -432,7 +432,7 @@ rnn.cell_1.eltwisemult_cell_forget:
     mean: 0.004142155777030078
     std: 0.5223340602367401
     shape: (10, 1500)
-rnn.cell_1.eltwisemult_cell_input:
+rnn.cells.1.eltwisemult_cell_input:
   inputs:
     0:
       min: 8.415258179184093e-08
@@ -458,7 +458,7 @@ rnn.cell_1.eltwisemult_cell_input:
     mean: 0.0014335625686553053
     std: 0.3403960596000532
     shape: (10, 1500)
-rnn.cell_1.eltwiseadd_cell:
+rnn.cells.1.eltwiseadd_cell:
   inputs:
     0:
       min: -48.25635528564453
@@ -484,7 +484,7 @@ rnn.cell_1.eltwiseadd_cell:
     mean: 0.0055757183429154394
     std: 0.6651466271437165
     shape: (10, 1500)
-rnn.cell_1.act_h:
+rnn.cells.1.act_h:
   inputs:
     0:
       min: -48.280601501464844
@@ -502,7 +502,7 @@ rnn.cell_1.act_h:
     mean: 0.0021621192841517175
     std: 0.38166194375973717
     shape: (10, 1500)
-rnn.cell_1.eltwisemult_hidden:
+rnn.cells.1.eltwisemult_hidden:
   inputs:
     0:
       min: 5.579678230560603e-08
diff --git a/examples/word_language_model/quantize_lstm.ipynb b/examples/word_language_model/quantize_lstm.ipynb
index 0e23390cc482f6c2e0b9e80382cf997e4a7c87a4..3014ba5794f4c7a384dbf41774eb8cea4b3c4206 100644
--- a/examples/word_language_model/quantize_lstm.ipynb
+++ b/examples/word_language_model/quantize_lstm.ipynb
@@ -426,10 +426,10 @@
     "import pprint\n",
     "pp = pprint.PrettyPrinter(indent=1)\n",
     "print('Stats BEFORE prepare_model:')\n",
-    "pp.pprint(stats_before_prepare['rnn.cell_0.eltwiseadd_gate']['output'])\n",
+    "pp.pprint(stats_before_prepare['rnn.cells.0.eltwiseadd_gate']['output'])\n",
     "\n",
     "print('\\nStats AFTER to prepare_model:')\n",
-    "pp.pprint(quantizer.model_activation_stats['rnn.cell_0.eltwiseadd_gate']['output'])"
+    "pp.pprint(quantizer.model_activation_stats['rnn.cells.0.eltwiseadd_gate']['output'])"
    ]
   },
   {
@@ -517,8 +517,8 @@
    }
   ],
   "source": [
-    "print(quantizer.model.rnn.cell_0.fc_gate_x)\n",
-    "print(quantizer.model.rnn.cell_0.eltwiseadd_gate)"
+    "print(quantizer.model.rnn.cells[0].fc_gate_x)\n",
+    "print(quantizer.model.rnn.cells[0].eltwiseadd_gate)"
    ]
   },
   {
diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py
index 9d932aac5fc2d7fdffced3ef6c10f268ae69a123..8bdc3fb69034ee53b0c531138561762db9912bb5 100644
--- a/tests/test_post_train_quant.py
+++ b/tests/test_post_train_quant.py
@@ -423,12 +423,12 @@ def rnn_model_stats(rnn_model):
     (None, ClipMode.AVG, 0),
     (distiller.utils.yaml_ordered_load("""
-    rnn.cell_0.eltwisemult_hidden:
+    rnn.cells.0.eltwisemult_hidden:
         clip_acts: NONE
     """),
      ClipMode.NONE, 0),
     (distiller.utils.yaml_ordered_load("""
-    rnn.cell_0.eltwisemult_hidden:
+    rnn.cells.0.eltwisemult_hidden:
         clip_acts: N_STD
         clip_n_stds: 2
     """),
      ClipMode.N_STD, 2)
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 927ccbaec81c751a1c87ee2ef72f827db63462a6..e220c90484c5a6ebfe39610a883fa734416b6cbb 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -18,13 +18,15 @@ import logging
 import torch
 import torch.nn as nn
 import pytest
+from copy import deepcopy
+from collections import OrderedDict
 import distiller
 from distiller.models import ALL_MODEL_NAMES, create_model
 from distiller.apputils import *
 from distiller import normalize_module_name, denormalize_module_name, \
     SummaryGraph, onnx_name_2_pytorch_name
 from distiller.model_summaries import connectivity_summary, connectivity_summary_verbose
-from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata
+from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata, _DistillerModuleList, _to_distiller_modulelist
 
 # Logging configuration
 logging.basicConfig(level=logging.DEBUG)
@@ -324,12 +326,16 @@ def test_adjacency_map(parallel, dedicated_modules):
             super(TestModel, self).__init__()
             self.conv = nn.Conv2d(3, 10, 5)
             self.bn = nn.BatchNorm2d(10)
-            self.relu = nn.ReLU()
+            self.post_conv_bn = nn.ModuleList([
+                nn.Tanh(),
+                nn.ReLU()
+            ])
 
         def forward(self, x):
             res = self.conv(x)
             y = self.bn(res)
-            y = self.relu(y)
+            for m in self.post_conv_bn:
+                y = m(y)
             return y + res
 
     def check_adj_entry(actual, expected):
@@ -346,13 +352,14 @@ def test_adjacency_map(parallel, dedicated_modules):
     adj_map = sg.adjacency_map(dedicated_modules_only=dedicated_modules)
 
     if dedicated_modules:
-        assert len(adj_map) == 3
-    else:
         assert len(adj_map) == 4
+    else:
+        assert len(adj_map) == 5
 
     conv_op_meta = OpSimpleMetadata(prefix + 'conv', 'Conv')
     bn_op_meta = OpSimpleMetadata(prefix + 'bn', 'BatchNormalization')
-    relu_op_meta = OpSimpleMetadata(prefix + 'relu', 'Relu')
+    tanh_op_meta = OpSimpleMetadata(prefix + 'post_conv_bn.0', 'Tanh')
+    relu_op_meta = OpSimpleMetadata(prefix + 'post_conv_bn.1', 'Relu')
     add_op_meta = OpSimpleMetadata('top_level_op', 'Add')
 
     name = conv_op_meta.name
@@ -365,13 +372,20 @@ def test_adjacency_map(parallel, dedicated_modules):
     assert name in adj_map
     expected = AdjacentsEntry(bn_op_meta)
     expected.predecessors = [conv_op_meta]
+    expected.successors = [tanh_op_meta]
+    check_adj_entry(adj_map[name], expected)
+
+    name = tanh_op_meta.name
+    assert name in adj_map
+    expected = AdjacentsEntry(tanh_op_meta)
+    expected.predecessors = [bn_op_meta]
     expected.successors = [relu_op_meta]
     check_adj_entry(adj_map[name], expected)
 
     name = relu_op_meta.name
     assert name in adj_map
    expected = AdjacentsEntry(relu_op_meta)
-    expected.predecessors = [bn_op_meta]
+    expected.predecessors = [tanh_op_meta]
     expected.successors = [] if dedicated_modules else [add_op_meta]
     check_adj_entry(adj_map[name], expected)
 
@@ -385,6 +399,158 @@ def test_adjacency_map(parallel, dedicated_modules):
     check_adj_entry(adj_map[name], expected)
 
 
+def test_distiller_module_list():
+    #############################################################
+    # Model for testing conversion of nested ModuleLists
+    #############################################################
+    class ListsModule(nn.Module):
+        def __init__(self):
+            super(ListsModule, self).__init__()
+            self.conv1 = nn.Conv2d(3, 10, 5)
+            self.post_conv1 = nn.ModuleList([
+                nn.BatchNorm2d(10),
+                nn.ModuleList([
+                    nn.ReLU(),
+                    nn.ModuleList([nn.Tanh(), nn.MaxPool2d(2)])])])
+            self.conv2 = nn.Conv2d(10, 20, 3)
+            self.post_conv2 = nn.ModuleList([nn.ReLU6(), nn.MaxPool2d(4)])
+
+            self.expected_mlist_to_dmlist = OrderedDict([
+                ('post_conv1', ['post_conv1']),
+                ('post_conv1.1', ['post_conv1', '1']),
+                ('post_conv1.1.1', ['post_conv1', '1', '1']),
+                ('post_conv2', ['post_conv2']),
+            ])
+            self.expected_list_contents_name_changes = OrderedDict([
+                ('post_conv1.0', 'post_conv1_0'),
+                ('post_conv1.1.0', 'post_conv1_1_0'),
+                ('post_conv1.1.1.0', 'post_conv1_1_1_0'),
+                ('post_conv1.1.1.1', 'post_conv1_1_1_1'),
+                ('post_conv2.0', 'post_conv2_0'),
+                ('post_conv2.1', 'post_conv2_1'),
+            ])
+
+        def forward(self, x):
+            x = self.conv1(x)
+            x = self.post_conv1[0](x)
+            x = self.post_conv1[1][0](x)
+            for m in self.post_conv1[1][1]:
+                x = m(x)
+            x = self.conv2(x)
+            for m in self.post_conv2:
+                x = m(x)
+            return x
+
+    #############################################################
+    # Model for testing conversion in case of nested containers
+    #############################################################
+    class Block(nn.Module):
+        def __init__(self, in_ch):
+            super(Block, self).__init__()
+            self.in_ch = in_ch
+            self.out_ch = in_ch * 2
+            self.conv = nn.Conv2d(in_ch, self.out_ch, 3)
+            self.post_conv = nn.ModuleList([nn.BatchNorm2d(self.out_ch), nn.ReLU()])
+
+        def forward(self, x):
+            x = self.conv(x)
+            for m in self.post_conv:
+                x = m(x)
+            return x
+
+    class BlocksModule(nn.Module):
+        def __init__(self):
+            super(BlocksModule, self).__init__()
+            self.block1 = Block(3)
+            self.blocks2_3 = nn.Sequential(Block(6), Block(12))
+            self.blocks4_5 = nn.ModuleList([Block(24), Block(48)])
+            self.block6 = Block(96)
+
+            self.expected_mlist_to_dmlist = OrderedDict([
+                ('block1.post_conv', ['block1', 'post_conv']),
+                ('blocks2_3.0.post_conv', ['blocks2_3', '0', 'post_conv']),
+                ('blocks2_3.1.post_conv', ['blocks2_3', '1', 'post_conv']),
+                ('blocks4_5', ['blocks4_5']),
+                ('blocks4_5.0.post_conv', ['blocks4_5', '0', 'post_conv']),
+                ('blocks4_5.1.post_conv', ['blocks4_5', '1', 'post_conv']),
+                ('block6.post_conv', ['block6', 'post_conv']),
+            ])
+            self.expected_list_contents_name_changes = OrderedDict([
+                ('block1.post_conv.0', 'block1.post_conv_0'),
+                ('block1.post_conv.1', 'block1.post_conv_1'),
+                ('blocks2_3.0.post_conv.0', 'blocks2_3.0.post_conv_0'),
+                ('blocks2_3.0.post_conv.1', 'blocks2_3.0.post_conv_1'),
+                ('blocks2_3.1.post_conv.0', 'blocks2_3.1.post_conv_0'),
+                ('blocks2_3.1.post_conv.1', 'blocks2_3.1.post_conv_1'),
+                ('blocks4_5.0', 'blocks4_5_0'),
+                ('blocks4_5.0.conv', 'blocks4_5_0.conv'),
+                ('blocks4_5.0.post_conv.0', 'blocks4_5_0.post_conv_0'),
+                ('blocks4_5.0.post_conv.1', 'blocks4_5_0.post_conv_1'),
+                ('blocks4_5.1', 'blocks4_5_1'),
+                ('blocks4_5.1.conv', 'blocks4_5_1.conv'),
+                ('blocks4_5.1.post_conv.0', 'blocks4_5_1.post_conv_0'),
+                ('blocks4_5.1.post_conv.1', 'blocks4_5_1.post_conv_1'),
+                ('block6.post_conv.0', 'block6.post_conv_0'),
+                ('block6.post_conv.1', 'block6.post_conv_1'),
+            ])
+
+        def forward(self, x):
+            x = self.block1(x)
+            x = self.blocks2_3(x)
+            for block in self.blocks4_5:
+                x = block(x)
+            x = self.block6(x)
+            return x
+
+    def check(m):
+        def check_equal_tensors(actual, expected):
+            assert (actual == expected).all().item() == 1
+
+        m_dml, converted_module_names_map = _to_distiller_modulelist(deepcopy(m))
+
+        # Check all modules converted as expected
+        named_modules_dmlist = OrderedDict(m_dml.named_modules())
+        for name_orig, module_orig in m.named_modules():
+            if name_orig in m.expected_mlist_to_dmlist:
+                # Check ModuleLists were converted to an attribute with the expected name, which is not
+                # registered as a module in the converted model
+                assert name_orig not in named_modules_dmlist
+                attr_dml = m_dml
+                for attr_name in m.expected_mlist_to_dmlist[name_orig]:
+                    try:
+                        attr_dml = attr_dml[int(attr_name)]
+                    except ValueError:
+                        attr_dml = getattr(attr_dml, attr_name)
+                assert isinstance(attr_dml, _DistillerModuleList)
+            else:
+                # Check module name changed as expected, and that the module type didn't change
+                expected_name_dml = m.expected_list_contents_name_changes.get(name_orig, name_orig)
+                assert expected_name_dml in named_modules_dmlist
+                assert expected_name_dml in converted_module_names_map
+                assert converted_module_names_map[expected_name_dml] == name_orig
+                assert type(named_modules_dmlist[expected_name_dml]) == type(module_orig)
+                converted_module_names_map.pop(expected_name_dml)
+                named_modules_dmlist.pop(expected_name_dml)
+
+        assert not converted_module_names_map, 'Unexpected contents in converted_module_names_map'
+        assert not named_modules_dmlist, 'Unexpected contents in converted model named_modules'
+
+        # Now make sure all parameters and buffers didn't change
+        for p_orig, p_dml in zip(m.parameters(), m_dml.parameters()):
+            check_equal_tensors(p_dml, p_orig)
+        for b_orig, b_dml in zip(m.buffers(), m_dml.buffers()):
+            check_equal_tensors(b_dml, b_orig)
+
+        # Check forward pass gives identical results
+        x = torch.randn(1, 3, 50, 50)
+        y_orig = m(x)
+        y_dml = m_dml(x)
+        check_equal_tensors(y_dml, y_orig)
+
+    check(ListsModule())
+    check(BlocksModule())
+
+
 if __name__ == '__main__':
     #test_connectivity_summary()
     test_sg_macs()
\ No newline at end of file
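
A minimal usage sketch of the conversion this patch introduces (it is not part of the diff itself). The toy model, tensor sizes and printed results below are illustrative assumptions; only SummaryGraph, _to_distiller_modulelist and the '_'-based flattened naming come from the code above:

import torch
import torch.nn as nn

from distiller import SummaryGraph
from distiller.summary_graph import _to_distiller_modulelist


class TinyNet(nn.Module):
    """Hypothetical model holding some of its layers in an nn.ModuleList."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.post_conv = nn.ModuleList([nn.BatchNorm2d(8), nn.ReLU()])

    def forward(self, x):
        x = self.conv(x)
        for m in self.post_conv:
            x = m(x)
        return x


# Explicit conversion: the ModuleList contents are re-registered directly on the
# parent module, with '_' replacing '.' (e.g. 'post_conv.0' -> 'post_conv_0')
converted, name_map = _to_distiller_modulelist(TinyNet())
print([name for name, _ in converted.named_modules()])
print(name_map)  # OrderedDict mapping converted names back to the original names

# SummaryGraph performs the same conversion internally on a clone of the model and
# uses the returned map to report ops under the original, un-flattened module names
model = TinyNet()
sg = SummaryGraph(model, torch.randn(1, 3, 32, 32))
print(list(sg.adjacency_map(dedicated_modules_only=True).keys()))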