diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index fcbd9342e833cba0829cda581a4a41c01e8ef9ca..8ede4b25b34207d68c542889c61c436be96d2b60 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -281,28 +281,28 @@ class Quantizer(object):
                 # We indicate this module wasn't replaced by a wrapper
                 replace_msg(full_name)
                 self.modules_processed[module] = full_name, None
-                continue
-
-            # We use a type hint comment to let IDEs know replace_fn is a function
-            replace_fn = self.replacement_factory[type(module)]  # type: Optional[Callable]
-            # If the replacement function wasn't specified - continue without replacing this module.
-            if replace_fn is not None:
-                valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name], replace_fn)
-                if invalid_kwargs:
-                    raise TypeError("""Quantizer of type %s doesn't accept \"%s\" 
-                                        as override arguments for %s. Allowed kwargs: %s"""
-                                    % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
-                new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
-                replace_msg(full_name, (module, new_module))
-                # Add to history of prepared submodules
-                self.modules_processed[module] = full_name, new_module
-                setattr(container, name, new_module)
-
-                # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
-                if not distiller.has_children(module) and distiller.has_children(new_module):
-                    for sub_module_name, sub_module in new_module.named_modules():
-                        self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
-                    self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)
+            else:
+                # We use a type hint comment to let IDEs know replace_fn is a function
+                replace_fn = self.replacement_factory[type(module)]  # type: Optional[Callable]
+                # If the replacement function wasn't specified - don't replace this module.
+                if replace_fn is not None:
+                    valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name],
+                                                                           replace_fn)
+                    if invalid_kwargs:
+                        raise TypeError("""Quantizer of type %s doesn't accept \"%s\" 
+                                            as override arguments for %s. Allowed kwargs: %s"""
+                                        % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
+                    new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
+                    replace_msg(full_name, (module, new_module))
+                    # Add to history of prepared submodules
+                    self.modules_processed[module] = full_name, new_module
+                    setattr(container, name, new_module)
+
+                    # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
+                    if not distiller.has_children(module) and distiller.has_children(new_module):
+                        for sub_module_name, sub_module in new_module.named_modules():
+                            self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
+                        self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)
 
             if distiller.has_children(module):
                 # For container we call recursively
diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py
index eecfa3ffe3b3065b607a7c0eac1675717701a066..049e563eda6d0c8034a9414af0bef43b58181716 100644
--- a/tests/test_quantizer.py
+++ b/tests/test_quantizer.py
@@ -147,30 +147,34 @@ class DummyQuantizer(Quantizer):
 #############################
 # Other utils
 #############################
-
-expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer}
-
-
 def params_quantizable(module):
     return isinstance(module, (nn.Conv2d, nn.Linear))
 
 
 def get_expected_qbits(model, qbits, expected_overrides):
-    expected_qbits = {}
-    post_prepare_changes = {}
+    expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer}
+
+    expected_qbits = OrderedDict()
+    post_prepare_qbits_changes = OrderedDict()
+    post_prepare_expected_types = OrderedDict()
     prefix = 'module.' if isinstance(model, torch.nn.DataParallel) else ''
     for orig_name, orig_module in model.named_modules():
+        orig_module_type = type(orig_module)
         bits_a, bits_w, bits_b = expected_overrides.get(orig_name.replace(prefix, '', 1), qbits)
         if not params_quantizable(orig_module):
             bits_w = bits_b = None
         expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b)
+        if expected_qbits[orig_name] == QBits(None, None, None):
+            post_prepare_expected_types[orig_name] = orig_module_type
+        else:
+            post_prepare_expected_types[orig_name] = expected_type_replacements.get(orig_module_type, orig_module_type)
+            # We're testing replacement of module with container
+            if post_prepare_expected_types[orig_name] == DummyWrapperLayer:
+                post_prepare_qbits_changes[orig_name] = QBits(bits_a, None, None)
+                post_prepare_qbits_changes[orig_name + '.inner'] = expected_qbits[orig_name]
+                post_prepare_expected_types[orig_name + '.inner'] = orig_module_type
 
-        # We're testing replacement of module with container
-        if isinstance(orig_module, (nn.Conv2d, nn.Linear)):
-            post_prepare_changes[orig_name] = QBits(bits_a, None, None)
-            post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name]
-
-    return expected_qbits, post_prepare_changes
+    return expected_qbits, post_prepare_qbits_changes, post_prepare_expected_types
 
 
 #############################
@@ -251,7 +255,10 @@ bias_key = 'bits_bias'
           'sub1.relu2': QBits(8, None, None), 'sub1.pool2': QBits(8, None, None)}),
         (QBits(8, 4, 32),
          OrderedDict([('conv1', {acts_key: 8, wts_key: 4, bias_key: None})]),
-         {'conv1': QBits(8, 4, None)})
+         {'conv1': QBits(8, 4, None)}),
+        (QBits(None, 8, 32),
+         OrderedDict([('conv1', {acts_key: 8, wts_key: 8, bias_key: 32})]),
+         {'conv1': QBits(8, 8, 32)})
     ],
     ids=[
         'no_override',
@@ -260,7 +267,8 @@ bias_key = 'bits_bias'
         'overlap_pattern_override_proper',  # "proper" ==> Specific pattern before broader pattern
         'overlap_pattern_override_wrong',   # "wrong" ==> Broad pattern before specific pattern, so specific pattern
                                             #             never actually matched
-        'wts_quant_bias_not'
+        'wts_quant_bias_not',
+        'dont_quant_acts'
     ]
 )
 def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overrides,
@@ -270,7 +278,9 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri
     m_orig = deepcopy(model)
 
     # Build expected QBits
-    expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides)
+    expected_qbits, post_prepare_changes, post_prepare_expected_types = get_expected_qbits(model,
+                                                                                           qbits,
+                                                                                           explicit_expected_overrides)
 
     # Initialize Quantizer
     q = DummyQuantizer(model, optimizer=optimizer,
@@ -317,15 +327,12 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri
 
         # Check module replacement is as expected
         q_module = q_named_modules[orig_name]
-        expected_type = expected_type_replacements.get(type(orig_module))
-        if expected_type is None or expected_qbits[orig_name] == QBits(None, None, None):
-            assert type(orig_module) == type(q_module)
-        else:
-            assert type(q_module) == expected_type
-            if expected_type == DummyWrapperLayer:
-                assert expected_qbits[orig_name + '.inner'] == q_module.qbits
-            else:
-                assert expected_qbits[orig_name] == q_module.qbits
+        expected_type = post_prepare_expected_types[orig_name]
+        assert type(q_module) == expected_type
+        if expected_type == DummyWrapperLayer:
+            assert expected_qbits[orig_name + '.inner'] == q_module.qbits
+        elif expected_type == DummyQuantLayer:
+            assert expected_qbits[orig_name] == q_module.qbits
 
 
 @pytest.mark.parametrize(
@@ -344,7 +351,7 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri
 def test_param_quantization(model, optimizer, qbits, overrides, explicit_expected_overrides,
                             train_with_fp_copy):
     # Build expected QBits
-    expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides)
+    expected_qbits, post_prepare_changes, _ = get_expected_qbits(model, qbits, explicit_expected_overrides)
 
     q = DummyQuantizer(model, optimizer=optimizer,
                        bits_activations=qbits.acts, bits_weights=qbits.wts, bits_bias=qbits.bias,