diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 8eee101ad4f535aa12f73732e557a1248d751db0..baa7f9c53cf4b58fe4e30ffe5debb6ed14c9087f 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -173,6 +173,9 @@ class Quantizer(object):
         self.train_with_fp_copy = train_with_fp_copy
         self.params_to_quantize = []
 
+        # A dictionary mapping replaced modules to their names and replacement wrappers.
+        self.modules_replaced = OrderedDict()
+
     def _add_qbits_entry(self, module_name, module_type, qbits):
         if module_type not in [nn.Conv2d, nn.Linear, nn.Embedding]:
             # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't
@@ -184,6 +187,26 @@
         self.module_overrides_map[module_name] = entry
 
     def prepare_model(self):
+        """
+        Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
+        and overrides configuration provided to __init__(), and according to the replacement_factory as
+        defined by the Quantizer sub-class being used.
+
+        Note:
+            If multiple sub-modules within the model actually reference the same module, then that module
+            is replaced only once, according to the configuration (bit-width and/or overrides) of the
+            first encountered reference.
+            Toy example: say a module is constructed using this bit of code:
+
+                shared_relu = nn.ReLU()
+                self.relu1 = shared_relu
+                self.relu2 = shared_relu
+
+            When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
+            Let's call it 'new_relu1'. When 'self.relu2' is encountered, it will simply be replaced
+            with a reference to 'new_relu1'. Any override configuration made specifically for
+            'self.relu2' will be ignored. A warning message will be shown.
+        """
         self._prepare_model_impl()
 
         msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
@@ -226,6 +249,14 @@
         # Iterate through model, insert quantization functions as appropriate
         for name, module in container.named_children():
             full_name = prefix + name
+            if module in self.modules_replaced:
+                previous_name, previous_wrapper = self.modules_replaced[module]
+                warnings.warn("Module '{0}' references the same module as '{1}'."
+                              ' Replacing with a reference to the same wrapper.'.format(full_name, previous_name),
+                              UserWarning)
+                msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, previous_wrapper))
+                setattr(container, name, previous_wrapper)
+                continue
             current_qbits = self.module_qbits_map[full_name]
             if current_qbits.acts is None and current_qbits.wts is None:
                 if self.module_overrides_map[full_name]:
@@ -241,6 +272,8 @@
             new_module = self.replacement_factory[type(module)](module, full_name, self.module_qbits_map,
                                                                 **valid_kwargs)
             msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module))
+            # Add to the history of replaced modules, so shared modules are only replaced once
+            self.modules_replaced[module] = full_name, new_module
             setattr(container, name, new_module)
 
             # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py
index fdd47c77a87b04548de64cdb892f71b4e4139070..fea1fe09ce772b78b0d522d220d41ece5914f164 100644
--- a/tests/test_quantizer.py
+++ b/tests/test_quantizer.py
@@ -83,6 +83,33 @@ class DummyModel(nn.Sequential):
         p.data = torch.zeros_like(p)
 
 
+class DummyDenseWithRelu(nn.Module):
+    def __init__(self, input_size, output_size, relu=None):
+        super(DummyDenseWithRelu, self).__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.relu = relu or nn.ReLU()
+        self.linear = nn.Linear(input_size, output_size)
+
+    def forward(self, x):
+        return self.relu(self.linear(x))
+
+
+class DummyModelWithSharedSubmodule(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(DummyModelWithSharedSubmodule, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+        self.dense1 = DummyDenseWithRelu(input_size, hidden_size)
+        self.dense2 = DummyDenseWithRelu(hidden_size, output_size, self.dense1.relu)
+
+    def forward(self, x):
+        x = self.dense1(x)
+        x = self.dense2(x)
+        return x
+
+
 #############################
 # Dummy Quantizer
 #############################
@@ -111,6 +138,7 @@
 
         self.replacement_factory[nn.Conv2d] = _dummy_wrapper_layer
         self.replacement_factory[nn.ReLU] = _dummy_quant_layer
+        self.replacement_factory[nn.Linear] = _dummy_wrapper_layer
         self.param_quantization_fn = dummy_quantize_params
 
 
@@ -118,7 +146,7 @@
 # Other utils
 #############################
 
-expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer}
+expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer}
 
 
 def params_quantizable(module):
@@ -136,7 +164,7 @@
         expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b)
 
         # We're testing replacement of module with container
-        if isinstance(orig_module, nn.Conv2d):
+        if isinstance(orig_module, (nn.Conv2d, nn.Linear)):
             post_prepare_changes[orig_name] = QBits(bits_a, None, None)
             post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name]
 
@@ -394,3 +422,16 @@ def test_overridable_args(model, optimizer, train_with_fp_copy):
     q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
     q.prepare_model()
     assert model_copy.relu1.overridable_prop == 123
+
+
+def test_shared_submodule(optimizer, train_with_fp_copy):
+    with pytest.warns(UserWarning,
+                      match="Module '{0}' references the same module as "
'{1}'.".format('dense2.relu', 'dense1.relu')): + densenet = DummyModelWithSharedSubmodule(1024, 1024, 1000) + quantizer = DummyQuantizer(densenet, + bits_weights=8, bits_activations=8, bits_bias=32, + optimizer=optimizer, + train_with_fp_copy=train_with_fp_copy) + quantizer.prepare_model() + assert isinstance(quantizer.model.dense1.relu, DummyQuantLayer) + assert quantizer.model.dense1.relu == quantizer.model.dense2.relu