diff --git a/distiller/config.py b/distiller/config.py
index 0e2fd628f0650a5694e2b0ff2b368bb52de5afeb..2661a2cea1afcdb3d6e4aee4c181a832be3f6d28 100755
--- a/distiller/config.py
+++ b/distiller/config.py
@@ -56,7 +56,7 @@ def dict_config(model, optimizer, sched_dict):
 
     pruners = __factory('pruners', model, sched_dict)
     regularizers = __factory('regularizers', model, sched_dict)
-    quantizers = __factory('quantizers', model, sched_dict)
+    quantizers = __factory('quantizers', model, sched_dict, optimizer=optimizer)
     if len(quantizers) > 1:
         print("\nError: Multiple Quantizers not supported")
         exit(1)
@@ -92,12 +92,6 @@ def dict_config(model, optimizer, sched_dict):
                 quantizer = quantizers[instance_name]
                 policy = distiller.QuantizationPolicy(quantizer)
 
-                # Quantizers for training modify the models parameters, need to update the optimizer
-                if quantizer.train_with_fp_copy:
-                    optimizer_type = type(optimizer)
-                    new_optimizer = optimizer_type(model.parameters(), **optimizer.defaults)
-                    optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
-
             elif 'lr_scheduler' in policy_def:
                 # LR schedulers take an optimizer in their CTOR, so postpone handling until we're certain
                 # a quantization policy was initialized (if exists)
diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py
index e50edbfe02fef5ca6eaea7227dfe21fa20c175e9..e174005241b14f5f6f1f0fe97dcabb9c850489a6 100644
--- a/distiller/quantization/__init__.py
+++ b/distiller/quantization/__init__.py
@@ -16,7 +16,7 @@
 
 from .quantizer import Quantizer
 from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, SymmetricLinearQuantizer
-from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer
+from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer
 
 del quantizer
 del range_linear
diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index 5854ab8b935f80d41d7b0dd0ac83ea61cd486eb8..2312dec5414bec8556101d13901a8ed07226aea5 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -43,6 +43,34 @@ class LinearQuantizeSTE(torch.autograd.Function):
         return grad_output, None, None, None
 
 
+class LearnedClippedLinearQuantizeSTE(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, clip_val, num_bits, dequantize, inplace):
+        ctx.save_for_backward(input, clip_val)
+        if inplace:
+            ctx.mark_dirty(input)
+        scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, clip_val.data[0])
+        output = clamp(input, 0, clip_val.data[0], inplace)
+        output = linear_quantize(output, scale_factor, inplace)
+        if dequantize:
+            output = linear_dequantize(output, scale_factor, inplace)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, clip_val = ctx.saved_tensors
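+        # Gradient w.r.t. the input: pass gradients through only where the input fell inside [0, clip_val]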
+        grad_input = grad_output.clone()
+        grad_input[input.le(0)] = 0
+        grad_input[input.ge(clip_val.data[0])] = 0
+
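+        # Gradient w.r.t. the clipping value ('alpha' in the PACT paper): only inputs clipped at the top contribute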
+        grad_alpha = grad_output.clone()
+        grad_alpha[input.lt(clip_val.data[0])] = 0
+        grad_alpha = grad_alpha.sum().expand_as(clip_val)
+
+        # Straight-through estimator for the scale factor calculation
+        return grad_input, grad_alpha, None, None, None
+
+
 class ClippedLinearQuantization(nn.Module):
     def __init__(self, num_bits, clip_val, dequantize=True, inplace=False):
         super(ClippedLinearQuantization, self).__init__()
@@ -63,6 +91,24 @@ class ClippedLinearQuantization(nn.Module):
                                                            inplace_str)
 
 
+class LearnedClippedLinearQuantization(nn.Module):
+    def __init__(self, num_bits, init_act_clip_val, dequantize=True, inplace=False):
+        super(LearnedClippedLinearQuantization, self).__init__()
+        self.num_bits = num_bits
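+        # The clipping value ('alpha' in the PACT paper) is a learned parameter, shared across the layer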
+        self.clip_val = nn.Parameter(torch.Tensor([init_act_clip_val]))
+        self.dequantize = dequantize
+        self.inplace = inplace
+
+    def forward(self, input):
+        input = LearnedClippedLinearQuantizeSTE.apply(input, self.clip_val, self.num_bits, self.dequantize, self.inplace)
+        return input
+
+    def __repr__(self):
+        inplace_str = ', inplace' if self.inplace else ''
+        return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val,
+                                                           inplace_str)
+
+
 class WRPNQuantizer(Quantizer):
     """
     Quantizer using the WRPN quantization scheme, as defined in:
@@ -72,10 +118,11 @@ class WRPNQuantizer(Quantizer):
         1. This class does not take care of layer widening as described in the paper
         2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
-    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(), quantize_bias=False):
-        super(WRPNQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
-                                            bits_overrides=bits_overrides, train_with_fp_copy=True,
-                                            quantize_bias=quantize_bias)
+    def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(),
+                 quantize_bias=False):
+        super(WRPNQuantizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
+                                            bits_weights=bits_weights, bits_overrides=bits_overrides,
+                                            train_with_fp_copy=True, quantize_bias=quantize_bias)
 
         def wrpn_quantize_param(param_fp, num_bits):
             scale_factor = symmetric_linear_quantization_scale_factor(num_bits, 1)
@@ -94,6 +141,15 @@ class WRPNQuantizer(Quantizer):
         self.replacement_factory[nn.ReLU] = relu_replace_fn
 
 
+def dorefa_quantize_param(param_fp, num_bits):
+    scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1)
+    out = param_fp.tanh()
+    out = out / (2 * out.abs().max()) + 0.5
+    out = LinearQuantizeSTE.apply(out, scale_factor, True, False)
+    out = 2 * out - 1
+    return out
+
+
 class DorefaQuantizer(Quantizer):
     """
     Quantizer using the DoReFa scheme, as defined in:
@@ -104,18 +160,11 @@ class DorefaQuantizer(Quantizer):
         1. Gradients quantization not supported yet
         2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
-    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(), quantize_bias=False):
-        super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
-                                              bits_overrides=bits_overrides, train_with_fp_copy=True,
-                                              quantize_bias=quantize_bias)
-
-        def dorefa_quantize_param(param_fp, num_bits):
-            scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1)
-            out = param_fp.tanh()
-            out = out / (2 * out.abs().max()) + 0.5
-            out = LinearQuantizeSTE.apply(out, scale_factor, True, False)
-            out = 2 * out - 1
-            return out
+    def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(),
+                 quantize_bias=False):
+        super(DorefaQuantizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
+                                              bits_weights=bits_weights, bits_overrides=bits_overrides,
+                                              train_with_fp_copy=True, quantize_bias=quantize_bias)
 
         def relu_replace_fn(module, name, qbits_map):
             bits_acts = qbits_map[name].acts
@@ -126,3 +175,44 @@ class DorefaQuantizer(Quantizer):
         self.param_quantization_fn = dorefa_quantize_param
 
         self.replacement_factory[nn.ReLU] = relu_replace_fn
+
+
+class PACTQuantizer(Quantizer):
+    """
+    Quantizer using the PACT quantization scheme, as defined in:
+    Choi et al., PACT: Parameterized Clipping Activation for Quantized Neural Networks
+    (https://arxiv.org/abs/1805.06085)
+
+    Args:
+        act_clip_init_val (float): Initial clipping value for activations, referred to as "alpha" in the paper
+            (default: 8.0)
+        act_clip_decay (float): L2 penalty applied to the clipping values, referred to as "lambda_alpha" in the paper.
+            If None then the optimizer's default weight decay value is used (default: None)
+    """
+    def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(),
+                 quantize_bias=False, act_clip_init_val=8.0, act_clip_decay=None):
+        super(PACTQuantizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
+                                            bits_weights=bits_weights, bits_overrides=bits_overrides,
+                                            train_with_fp_copy=True, quantize_bias=quantize_bias)
+
+        def relu_replace_fn(module, name, qbits_map):
+            bits_acts = qbits_map[name].acts
+            if bits_acts is None:
+                return module
+            return LearnedClippedLinearQuantization(bits_acts, act_clip_init_val, dequantize=True,
+                                                    inplace=module.inplace)
+
+        self.param_quantization_fn = dorefa_quantize_param
+
+        self.replacement_factory[nn.ReLU] = relu_replace_fn
+
+        self.act_clip_decay = act_clip_decay
+
+    # In PACT, activations are quantized using LearnedClippedLinearQuantization, which contains a learned
+    # 'clip_val' parameter. We optimize this value separately from the main model parameters
+    def _get_updated_optimizer_params_groups(self):
+        base_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' not in name]}
+        clip_val_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' in name]}
+        if self.act_clip_decay is not None:
+            clip_val_group['weight_decay'] = self.act_clip_decay
+        return [base_group, clip_val_group]
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 77ef217356432eb9da867525eafa510c60bfd8ee..0a5b976db78994627eca2c89a6bd24abe5794c02 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -58,6 +58,9 @@ class Quantizer(object):
 
     Args:
         model (torch.nn.Module): The model to be quantized
+        optimizer (torch.optim.Optimizer): An optimizer instance, required whenever the quantizer modifies existing
+            model parameters and/or adds new ones.
+            Specifically, when train_with_fp_copy is True, this cannot be None.
         bits_activations/weights (int): Default number of bits to use when quantizing each tensor type.
             Value of None means do not quantize.
         bits_overrides (OrderedDict): Dictionary mapping regular expressions of layer name patterns to dictionary with
@@ -77,15 +80,18 @@ class Quantizer(object):
                 3.2 We also back-prop through the 'quantize' operation from step 1
             4. Update fp_weights with gradients calculated in step 3.2
     """
-    def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(),
+    def __init__(self, model, optimizer=None, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(),
                  quantize_bias=False, train_with_fp_copy=False):
         if not isinstance(bits_overrides, OrderedDict):
             raise TypeError('bits_overrides must be an instance of collections.OrderedDict')
+        if train_with_fp_copy and optimizer is None:
+            raise ValueError('optimizer cannot be None when train_with_fp_copy is True')
 
         self.default_qbits = QBits(acts=bits_activations, wts=bits_weights)
         self.quantize_bias = quantize_bias
 
         self.model = model
+        self.optimizer = optimizer
 
         # Stash some quantizer data in the model so we can re-apply the quantizer on a resuming model
         self.model.quantizer_metadata = {'type': type(self),
@@ -166,6 +172,12 @@ class Quantizer(object):
                 msglogger.info(
                     "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, qbits.wts))
 
+        # If an optimizer was passed, assume it needs to be updated to reflect the parameter changes made above
+        if self.optimizer:
+            optimizer_type = type(self.optimizer)
+            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
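+            # Update the existing optimizer in-place, so references held by the caller (e.g. the training loop) stay valid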
+            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
+
         msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
 
     def _pre_process_container(self, container, prefix=''):
@@ -192,6 +204,20 @@ class Quantizer(object):
                 # For container we call recursively
                 self._pre_process_container(module, full_name + '.')
 
+    def _get_updated_optimizer_params_groups(self):
+        """
+        Returns a list of model parameter groups and optimizer hyper-parameter overrides,
+        as expected by the __init__ function of torch.optim.Optimizer.
+        This is called after all model changes were made in prepare_model, in case an Optimizer instance was
+        passed to __init__.
+
+        Subclasses which add parameters to the model should override as needed.
+
+        :return: List of parameter groups
+        """
+        # Default implementation - just return all model parameters as one group
+        return [{'params': self.model.parameters()}]
+
     def quantize_params(self):
         """
         Quantize all parameters using the parameters using self.param_quantization_fn (using the defined number
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index ff7fd9d7e6d8b8102cce63858acbe02388819f9f..fe4e4637d2845d73772983cb4f80e1811ae4c0f8 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -148,9 +148,6 @@ class CompressionScheduler(object):
                     name_parts[-1] = name_parts[-1].replace(FP_BKP_PREFIX, '', 1)
                     name = '.'.join(name_parts)
                     self.zeros_mask_dict[name].apply_mask(param)
-                else:
-                    raise
-
 
     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.
diff --git a/docs-src/docs/algo_quantization.md b/docs-src/docs/algo_quantization.md
index 00bf2fdef1133b855e3e90ce08bc8bc199930261..d2fbc065266dfd9284354b9db509bc9fd80f110d 100644
--- a/docs-src/docs/algo_quantization.md
+++ b/docs-src/docs/algo_quantization.md
@@ -29,6 +29,14 @@ This method requires training the model with quantization, as discussed [here](q
 - Gradients quantization as proposed in the paper is not supported yet.
 - The paper defines special handling for binary weights which isn't supported in Distiller yet.
 
+## PACT
+
+(As proposed in [PACT: Parameterized Clipping Activation for Quantized Neural Networks](https://arxiv.org/abs/1805.06085))
+
+This method is similar to DoReFa, but the upper clipping values, \(\alpha\), of the activation functions are learned parameters instead of being hard-coded to 1. Note that, per the paper's recommendation, \(\alpha\) is shared per layer.
+
+This method requires training the model with quantization, as discussed [here](quantization/#training-with-quantization). Use the `PACTQuantizer` class to transform an existing model to a model suitable for training with quantization using PACT.
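+
+A minimal usage sketch (model construction and the SGD hyper-parameters below are placeholders; the `PACTQuantizer` arguments follow its constructor):
+
+```python
+import torch.optim as optim
+from distiller.quantization import PACTQuantizer
+
+model = build_model()  # placeholder - any PyTorch model
+optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
+
+# Quantize activations and weights to 4 bits; give the learned clip values their own
+# L2 decay ("lambda_alpha" in the paper)
+quantizer = PACTQuantizer(model, optimizer, bits_activations=4, bits_weights=4,
+                          act_clip_init_val=8.0, act_clip_decay=5e-4)
+quantizer.prepare_model()  # replaces ReLUs and updates the optimizer's parameter groups
+```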
+
 ## WRPN
 
 (As proposed in [WRPN: Wide Reduced-Precision Networks](https://arxiv.org/abs/1709.01134))  
diff --git a/docs-src/docs/design.md b/docs-src/docs/design.md
index 62f66d23046736457338d74aa03a9bad6ba9ea2d..f6f2be507fe625a0cf83fea9ad28f9b3fb299f5a 100755
--- a/docs-src/docs/design.md
+++ b/docs-src/docs/design.md
@@ -76,13 +76,23 @@ The `Quantizer` class also provides an API to quantize the weights of all layers
 
 ### Training with Quantization
 
-The `Quantizer` class supports training with quantization in the loop, as described [here](quantization.md#training-with-quantization). This is enabled by setting `train_with_fp_copy=True` in the `Quantizer` constructor. At model transformation, in each module that has parameters that should be quantized, a new `torch.nn.Parameter` is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module **is not** created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack":
+The `Quantizer` class supports training with quantization in the loop. This requires handling two scenarios:
 
-1. The existing `torch.nn.Parameter`, e.g. `weights`, is replaced by a `torch.nn.Parameter` named `float_weight`.
-2. To maintain the existing functionality of the module, we then register a `buffer` in the module with the original name - `weights`.
-3. During training, `float_weight` will be passed to `param_quantization_fn` and the result will be stored in `weight`.
+1. Maintaining a full precision copy of the weights, as described [here](quantization.md#training-with-quantization). This is enabled by setting `train_with_fp_copy=True` in the `Quantizer` constructor. At model transformation, in each module that has parameters that should be quantized, a new `torch.nn.Parameter` is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module **is not** created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to do this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack":
 
-**Important Note**: Since this process modifies the model's parameters, it must be done **before** an PyTorch `Optimizer` is created (this refers to any of the sub-classes defined [here](https://pytorch.org/docs/stable/optim.html#algorithms)).
+    1. The existing `torch.nn.Parameter`, e.g. `weights`, is replaced by a `torch.nn.Parameter` named `float_weight`.
+    2. To maintain the existing functionality of the module, we then register a `buffer` in the module with the original name - `weights`.
+    3. During training, `float_weight` will be passed to `param_quantization_fn` and the result will be stored in `weight`.
+
+2. In addition, some quantization methods may introduce additional learned parameters to the model. For example, in the [PACT](algo_quantization.md#pact) method, activations are clipped to a value \(\alpha\), which is a learned per-layer parameter.
+
+To support these two cases, the `Quantizer` class also accepts an instance of a `torch.optim.Optimizer` (normally this would be an instance of one of its sub-classes). The quantizer will take care of modifying the optimizer according to the changes made to the parameters.
+
+!!! Note "Optimizing New Parameters"
+    In cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the specific method should override `Quantizer._get_updated_optimizer_params_groups()`, and return the proper groups plus any desired hyper-parameter overrides, as sketched below.
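+
+For instance, a sketch of such an override in a hypothetical sub-class (the in-tree example is `PACTQuantizer`):
+
+```python
+from distiller.quantization import Quantizer
+
+class MyQuantizer(Quantizer):
+    # Hypothetical quantizer which adds per-layer 'scale' parameters to the model
+    def _get_updated_optimizer_params_groups(self):
+        base_group = {'params': [p for n, p in self.model.named_parameters() if 'scale' not in n]}
+        # Optimize the new parameters in their own group, with an overridden learning rate
+        scale_group = {'params': [p for n, p in self.model.named_parameters() if 'scale' in n], 'lr': 0.01}
+        return [base_group, scale_group]
+```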
+
+### Examples
 
 The base `Quantizer` class is implemented in `distiller/quantization/quantizer.py`.  
-For a simple sub-class implementing symmetric linear quantization, see `SymmetricLinearQuantizer` in `distiller/quantization/range_linear.py`. For examples of lower-precision methods using training with quantization see `DorefaQuantizer` and `WRPNQuantizer` in `distiller/quantization/clipped_linear.py`
+For a simple sub-class implementing symmetric linear quantization, see `SymmetricLinearQuantizer` in `distiller/quantization/range_linear.py`.  
+In `distiller/quantization/clipped_linear.py` there are examples of lower-precision methods which use training with quantization. Specifically, see `PACTQuantizer` for an example of overriding `Quantizer._get_updated_optimizer_params_groups()`.
diff --git a/docs/algo_quantization/index.html b/docs/algo_quantization/index.html
index 4888886cf16052e5bc0c194804e99aab71a0d424..2225104c6fc0dc4706e2513bf8edddd58c440de3 100644
--- a/docs/algo_quantization/index.html
+++ b/docs/algo_quantization/index.html
@@ -106,6 +106,8 @@
         
             <li><a class="toctree-l4" href="#dorefa">DoReFa</a></li>
         
+            <li><a class="toctree-l4" href="#pact">PACT</a></li>
+        
             <li><a class="toctree-l4" href="#wrpn">WRPN</a></li>
         
             <li><a class="toctree-l4" href="#symmetric-linear-quantization">Symmetric Linear Quantization</a></li>
@@ -195,6 +197,10 @@
 <li>Gradients quantization as proposed in the paper is not supported yet.</li>
 <li>The paper defines special handling for binary weights which isn't supported in Distiller yet.</li>
 </ul>
+<h2 id="pact">PACT</h2>
+<p>(As proposed in <a href="https://arxiv.org/abs/1805.06085">PACT: Parameterized Clipping Activation for Quantized Neural Networks</a>)</p>
+<p>This method is similar to DoReFa, but the upper clipping values, <script type="math/tex">\alpha</script>, of the activation functions are learned parameters instead of being hard-coded to 1. Note that, per the paper's recommendation, <script type="math/tex">\alpha</script> is shared per layer.</p>
+<p>This method requires training the model with quantization, as discussed <a href="../quantization/#training-with-quantization">here</a>. Use the <code>PACTQuantizer</code> class to transform an existing model to a model suitable for training with quantization using PACT.</p>
 <h2 id="wrpn">WRPN</h2>
 <p>(As proposed in <a href="https://arxiv.org/abs/1709.01134">WRPN: Wide Reduced-Precision Networks</a>)  </p>
 <p>In this method, activations are clipped to <script type="math/tex">[0, 1]</script> and quantized as follows (<script type="math/tex">k</script> is the number of bits used for quantization):</p>
diff --git a/docs/design/index.html b/docs/design/index.html
index 9d256ff7541cc62eb6d4b58910a1ec0bd4c2b034..47497971ffcd83660ef52b1d5d23625f20823539 100644
--- a/docs/design/index.html
+++ b/docs/design/index.html
@@ -231,15 +231,29 @@ To execute the model transformation, call the <code>prepare_model</code> functio
 <h3 id="weights-quantization">Weights Quantization</h3>
 <p>The <code>Quantizer</code> class also provides an API to quantize the weights of all layers at once. To use it, the <code>param_quantization_fn</code> attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the <code>Quantizer</code> class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the <code>quantize_params</code> function can be called, which will iterate over all parameters and quantize them using <code>params_quantization_fn</code>.</p>
 <h3 id="training-with-quantization">Training with Quantization</h3>
-<p>The <code>Quantizer</code> class supports training with quantization in the loop, as described <a href="../quantization/index.html#training-with-quantization">here</a>. This is enabled by setting <code>train_with_fp_copy=True</code> in the <code>Quantizer</code> constructor. At model transformation, in each module that has parameters that should be quantized, a new <code>torch.nn.Parameter</code> is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module <strong>is not</strong> created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack":</p>
+<p>The <code>Quantizer</code> class supports training with quantization in the loop. This requires handling two scenarios:</p>
+<ol>
+<li>
+<p>Maintaining a full precision copy of the weights, as described <a href="../quantization/index.html#training-with-quantization">here</a>. This is enabled by setting <code>train_with_fp_copy=True</code> in the <code>Quantizer</code> constructor. At model transformation, in each module that has parameters that should be quantized, a new <code>torch.nn.Parameter</code> is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module <strong>is not</strong> created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to do this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack":</p>
 <ol>
 <li>The existing <code>torch.nn.Parameter</code>, e.g. <code>weights</code>, is replaced by a <code>torch.nn.Parameter</code> named <code>float_weight</code>.</li>
 <li>To maintain the existing functionality of the module, we then register a <code>buffer</code> in the module with the original name - <code>weights</code>.</li>
 <li>During training, <code>float_weight</code> will be passed to <code>param_quantization_fn</code> and the result will be stored in <code>weight</code>.</li>
 </ol>
-<p><strong>Important Note</strong>: Since this process modifies the model's parameters, it must be done <strong>before</strong> an PyTorch <code>Optimizer</code> is created (this refers to any of the sub-classes defined <a href="https://pytorch.org/docs/stable/optim.html#algorithms">here</a>).</p>
+</li>
+<li>
+<p>In addition, some quantization methods may introduce additional learned parameters to the model. For example, in the <a href="../algo_quantization/index.html#pact">PACT</a> method, activations are clipped to a value <script type="math/tex">\alpha</script>, which is a learned per-layer parameter.</p>
+</li>
+</ol>
+<p>To support these two cases, the <code>Quantizer</code> class also accepts an instance of a <code>torch.optim.Optimizer</code> (normally this would be an instance of one of its sub-classes). The quantizer will take care of modifying the optimizer according to the changes made to the parameters.</p>
+<div class="admonition note">
+<p class="admonition-title">Optimizing New Parameters</p>
+<p>In cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the specific method should override <code>Quantizer._get_updated_optimizer_params_groups()</code>, and return the proper groups plus any desired hyper-parameter overrides.</p>
+</div>
+<h3 id="examples">Examples</h3>
 <p>The base <code>Quantizer</code> class is implemented in <code>distiller/quantization/quantizer.py</code>.<br />
-For a simple sub-class implementing symmetric linear quantization, see <code>SymmetricLinearQuantizer</code> in <code>distiller/quantization/range_linear.py</code>. For examples of lower-precision methods using training with quantization see <code>DorefaQuantizer</code> and <code>WRPNQuantizer</code> in <code>distiller/quantization/clipped_linear.py</code></p>
+For a simple sub-class implementing symmetric linear quantization, see <code>SymmetricLinearQuantizer</code> in <code>distiller/quantization/range_linear.py</code>.<br />
+In <code>distiller/quantization/clipped_linear.py</code> there are examples of lower-precision methods which use training with quantization. Specifically, see <code>PACTQuantizer</code> for an example of overriding <code>Quantizer._get_updated_optimizer_params_groups()</code>.</p>
               
             </div>
           </div>
diff --git a/docs/index.html b/docs/index.html
index 2170e16e518e8e20ddec5e25b35da29eebd0fcdb..3fdb6e30888ec15897c32d55d6a14158dfe89b15 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -246,5 +246,5 @@ And of course, if we used a sparse or compressed representation, then we are red
 
 <!--
 MkDocs version : 0.17.2
-Build Date UTC : 2018-07-17 10:59:29
+Build Date UTC : 2018-07-22 11:48:56
 -->
diff --git a/docs/search/search_index.json b/docs/search/search_index.json
index f2425612ef706f0b10a590a48bc675df0a1a9b0b..be04e4340483834cb48cc927d45ce9357279ebc3 100644
--- a/docs/search/search_index.json
+++ b/docs/search/search_index.json
@@ -367,7 +367,7 @@
         }, 
         {
             "location": "/algo_quantization/index.html", 
-            "text": "Quantization Algorithms\n\n\nThe following quantization methods are currently implemented in Distiller:\n\n\nDoReFa\n\n\n(As proposed in \nDoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients\n)  \n\n\nIn this method, we first define the quantization function \nquantize_k\n, which takes a real value \na_f \\in [0, 1]\n and outputs a discrete-valued \na_q \\in \\left\\{ \\frac{0}{2^k-1}, \\frac{1}{2^k-1}, ... , \\frac{2^k-1}{2^k-1} \\right\\}\n, where \nk\n is the number of bits used for quantization.\n\n\n\n\na_q = quantize_k(a_f) = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) a_f \\right)\n\n\n\n\nActivations are clipped to the \n[0, 1]\n range and then quantized as follows:\n\n\n\n\nx_q = quantize_k(x_f)\n\n\n\n\nFor weights, we define the following function \nf\n, which takes an unbounded real valued input and outputs a real value in \n[0, 1]\n:\n\n\n\n\nf(w) = \\frac{tanh(w)}{2 max(|tanh(w)|)} + \\frac{1}{2} \n\n\n\n\nNow we can use \nquantize_k\n to get quantized weight values, as follows:\n\n\n\n\nw_q = 2 quantize_k \\left( f(w_f) \\right) - 1\n\n\n\n\nThis method requires training the model with quantization, as discussed \nhere\n. Use the \nDorefaQuantizer\n class to transform an existing model to a model suitable for training with quantization using DoReFa.\n\n\nNotes:\n\n\n\n\nGradients quantization as proposed in the paper is not supported yet.\n\n\nThe paper defines special handling for binary weights which isn't supported in Distiller yet.\n\n\n\n\nWRPN\n\n\n(As proposed in \nWRPN: Wide Reduced-Precision Networks\n)  \n\n\nIn this method, activations are clipped to \n[0, 1]\n and quantized as follows (\nk\n is the number of bits used for quantization):\n\n\n\n\nx_q = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) x_f \\right)\n\n\n\n\nWeights are clipped to \n[-1, 1]\n and quantized as follows:\n\n\n\n\nw_q = \\frac{1}{2^{k-1}-1} round \\left( \\left(2^{k-1} - 1 \\right)w_f \\right)\n\n\n\n\nNote that \nk-1\n bits are used to quantize weights, leaving one bit for sign.\n\n\nThis method requires training the model with quantization, as discussed \nhere\n. Use the \nWRPNQuantizer\n class to transform an existing model to a model suitable for training with quantization using WRPN.\n\n\nNotes:\n\n\n\n\nThe paper proposed widening of layers as a means to reduce accuracy loss. This isn't implemented as part of \nWRPNQuantizer\n at the moment. To experiment with this, modify your model implementation to have wider layers.\n\n\nThe paper defines special handling for binary weights which isn't supported in Distiller yet.\n\n\n\n\nSymmetric Linear Quantization\n\n\nIn this method, a float value is quantized by multiplying with a numeric constant (the \nscale factor\n), hence it is \nLinear\n. We use a signed integer to represent the quantized range, with no quantization bias (or \"offset\") used. As a result, the floating-point range considered for quantization is \nsymmetric\n with respect to zero.\n\nIn the current implementation the scale factor is chosen so that the entire range of the floating-point tensor is quantized (we do not attempt to remove outliers).\n\nLet us denote the original floating-point tensor by \nx_f\n, the quantized tensor by \nx_q\n, the scale factor by \nq_x\n and the number of bits used for quantization by \nn\n. 
Then, we get:\n\nq_x = \\frac{2^{n-1}-1}{\\max|x|}\n\n\nx_q = round(q_x x_f)\n\n(The \nround\n operation is round-to-nearest-integer)  \n\n\nLet's see how a \nconvolution\n or \nfully-connected (FC)\n layer is quantized using this method: (we denote input, output, weights and bias with  \nx, y, w\n and \nb\n respectively)\n\ny_f = \\sum{x_f w_f} + b_f = \\sum{\\frac{x_q}{q_x} \\frac{w_q}{q_w}} + \\frac{b_q}{q_b} = \\frac{1}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right)\n\n\ny_q = round(q_y y_f) = round\\left(\\frac{q_y}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right) \\right) \n\nNote how the bias has to be re-scaled to match the scale of the summation.\n\n\nImplementation\n\n\nWe've implemented \nconvolution\n and \nFC\n using this method.  \n\n\n\n\nThey are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the \nRangeLinearQuantParamLayerWrapper\n class.  \n\n\nAll other layers are unaffected and are executed using their original FP32 implementation.  \n\n\nTo automatically transform an existing model to a quantized model using this method, use the \nSymmetricLinearQuantizer\n class.\n\n\nFor weights and bias the scale factor is determined once at quantization setup (\"offline\"), and for activations it is determined dynamically at runtime (\"online\").  \n\n\nImportant note:\n Currently, this method is implemented as \ninference only\n, with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, with no re-training. As such, using it with \nn < 8\n is likely to lead to severe accuracy degradation for any non-trivial workload.", 
+            "text": "Quantization Algorithms\n\n\nThe following quantization methods are currently implemented in Distiller:\n\n\nDoReFa\n\n\n(As proposed in \nDoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients\n)  \n\n\nIn this method, we first define the quantization function \nquantize_k\n, which takes a real value \na_f \\in [0, 1]\n and outputs a discrete-valued \na_q \\in \\left\\{ \\frac{0}{2^k-1}, \\frac{1}{2^k-1}, ... , \\frac{2^k-1}{2^k-1} \\right\\}\n, where \nk\n is the number of bits used for quantization.\n\n\n\n\na_q = quantize_k(a_f) = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) a_f \\right)\n\n\n\n\nActivations are clipped to the \n[0, 1]\n range and then quantized as follows:\n\n\n\n\nx_q = quantize_k(x_f)\n\n\n\n\nFor weights, we define the following function \nf\n, which takes an unbounded real valued input and outputs a real value in \n[0, 1]\n:\n\n\n\n\nf(w) = \\frac{tanh(w)}{2 max(|tanh(w)|)} + \\frac{1}{2} \n\n\n\n\nNow we can use \nquantize_k\n to get quantized weight values, as follows:\n\n\n\n\nw_q = 2 quantize_k \\left( f(w_f) \\right) - 1\n\n\n\n\nThis method requires training the model with quantization, as discussed \nhere\n. Use the \nDorefaQuantizer\n class to transform an existing model to a model suitable for training with quantization using DoReFa.\n\n\nNotes:\n\n\n\n\nGradients quantization as proposed in the paper is not supported yet.\n\n\nThe paper defines special handling for binary weights which isn't supported in Distiller yet.\n\n\n\n\nPACT\n\n\n(As proposed in \nPACT: Parameterized Clipping Activation for Quantized Neural Networks\n)\n\n\nThis method is similar to DoReFa, but the upper clipping values, \n\\alpha\n, of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation, \n\\alpha\n is shared per layer.\n\n\nThis method requires training the model with quantization, as discussed \nhere\n. Use the \nPACTQuantizer\n class to transform an existing model to a model suitable for training with quantization using PACT.\n\n\nWRPN\n\n\n(As proposed in \nWRPN: Wide Reduced-Precision Networks\n)  \n\n\nIn this method, activations are clipped to \n[0, 1]\n and quantized as follows (\nk\n is the number of bits used for quantization):\n\n\n\n\nx_q = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) x_f \\right)\n\n\n\n\nWeights are clipped to \n[-1, 1]\n and quantized as follows:\n\n\n\n\nw_q = \\frac{1}{2^{k-1}-1} round \\left( \\left(2^{k-1} - 1 \\right)w_f \\right)\n\n\n\n\nNote that \nk-1\n bits are used to quantize weights, leaving one bit for sign.\n\n\nThis method requires training the model with quantization, as discussed \nhere\n. Use the \nWRPNQuantizer\n class to transform an existing model to a model suitable for training with quantization using WRPN.\n\n\nNotes:\n\n\n\n\nThe paper proposed widening of layers as a means to reduce accuracy loss. This isn't implemented as part of \nWRPNQuantizer\n at the moment. To experiment with this, modify your model implementation to have wider layers.\n\n\nThe paper defines special handling for binary weights which isn't supported in Distiller yet.\n\n\n\n\nSymmetric Linear Quantization\n\n\nIn this method, a float value is quantized by multiplying with a numeric constant (the \nscale factor\n), hence it is \nLinear\n. We use a signed integer to represent the quantized range, with no quantization bias (or \"offset\") used. 
As a result, the floating-point range considered for quantization is \nsymmetric\n with respect to zero.\n\nIn the current implementation the scale factor is chosen so that the entire range of the floating-point tensor is quantized (we do not attempt to remove outliers).\n\nLet us denote the original floating-point tensor by \nx_f\n, the quantized tensor by \nx_q\n, the scale factor by \nq_x\n and the number of bits used for quantization by \nn\n. Then, we get:\n\nq_x = \\frac{2^{n-1}-1}{\\max|x|}\n\n\nx_q = round(q_x x_f)\n\n(The \nround\n operation is round-to-nearest-integer)  \n\n\nLet's see how a \nconvolution\n or \nfully-connected (FC)\n layer is quantized using this method: (we denote input, output, weights and bias with  \nx, y, w\n and \nb\n respectively)\n\ny_f = \\sum{x_f w_f} + b_f = \\sum{\\frac{x_q}{q_x} \\frac{w_q}{q_w}} + \\frac{b_q}{q_b} = \\frac{1}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right)\n\n\ny_q = round(q_y y_f) = round\\left(\\frac{q_y}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right) \\right) \n\nNote how the bias has to be re-scaled to match the scale of the summation.\n\n\nImplementation\n\n\nWe've implemented \nconvolution\n and \nFC\n using this method.  \n\n\n\n\nThey are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the \nRangeLinearQuantParamLayerWrapper\n class.  \n\n\nAll other layers are unaffected and are executed using their original FP32 implementation.  \n\n\nTo automatically transform an existing model to a quantized model using this method, use the \nSymmetricLinearQuantizer\n class.\n\n\nFor weights and bias the scale factor is determined once at quantization setup (\"offline\"), and for activations it is determined dynamically at runtime (\"online\").  \n\n\nImportant note:\n Currently, this method is implemented as \ninference only\n, with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, with no re-training. As such, using it with \nn < 8\n is likely to lead to severe accuracy degradation for any non-trivial workload.", 
             "title": "Quantization"
         }, 
         {
@@ -385,6 +385,11 @@
             "text": "Gradients quantization as proposed in the paper is not supported yet.  The paper defines special handling for binary weights which isn't supported in Distiller yet.", 
             "title": "Notes:"
         }, 
+        {
+            "location": "/algo_quantization/index.html#pact", 
+            "text": "(As proposed in  PACT: Parameterized Clipping Activation for Quantized Neural Networks )  This method is similar to DoReFa, but the upper clipping values,  \\alpha , of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation,  \\alpha  is shared per layer.  This method requires training the model with quantization, as discussed  here . Use the  PACTQuantizer  class to transform an existing model to a model suitable for training with quantization using PACT.", 
+            "title": "PACT"
+        }, 
         {
             "location": "/algo_quantization/index.html#wrpn", 
             "text": "(As proposed in  WRPN: Wide Reduced-Precision Networks )    In this method, activations are clipped to  [0, 1]  and quantized as follows ( k  is the number of bits used for quantization):   x_q = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) x_f \\right)   Weights are clipped to  [-1, 1]  and quantized as follows:   w_q = \\frac{1}{2^{k-1}-1} round \\left( \\left(2^{k-1} - 1 \\right)w_f \\right)   Note that  k-1  bits are used to quantize weights, leaving one bit for sign.  This method requires training the model with quantization, as discussed  here . Use the  WRPNQuantizer  class to transform an existing model to a model suitable for training with quantization using WRPN.", 
@@ -527,7 +532,7 @@
         }, 
         {
             "location": "/design/index.html", 
-            "text": "Distiller design\n\n\nDistiller is designed to be easily integrated into your own PyTorch research applications.\n\nIt is easiest to understand this integration by examining the code of the sample application for compressing image classification models (\ncompress_classifier.py\n).\n\n\nThe application borrows its main flow code from torchvision's ImageNet classification training sample application (https://github.com/pytorch/examples/tree/master/imagenet). We tried to keep it similar, in order to make it familiar and easy to understand.\n\n\nIntegrating compression is very simple: simply add invocations of the appropriate compression_scheduler callbacks, for each stage in the training.  The training skeleton looks like the pseudo code below.  The boiler-plate Pytorch classification training is speckled with invocations of CompressionScheduler.\n\n\nFor each epoch:\n    compression_scheduler.on_epoch_begin(epoch)\n    train()\n    validate()\n    save_checkpoint()\n    compression_scheduler.on_epoch_end(epoch)\n\ntrain():\n    For each training step:\n        compression_scheduler.on_minibatch_begin(epoch)\n        output = model(input_var)\n        loss = criterion(output, target_var)\n        compression_scheduler.before_backward_pass(epoch)\n        loss.backward()\n        optimizer.step()\n        compression_scheduler.on_minibatch_end(epoch)\n\n\n\n\nThese callbacks can be seen in the diagram below, as the arrow pointing from the Training Loop and into Distiller's \nScheduler\n, which invokes the correct algorithm.  The application also uses Distiller services to collect statistics in \nSummaries\n and logs files, which can be queried at a later time, from Jupyter notebooks or TensorBoard.\n\n\n\n\nSparsification and fine-tuning\n\n\n\n\nThe application sets up a model as normally done in PyTorch.\n\n\nAnd then instantiates a Scheduler and configures it:\n\n\nScheduler configuration is defined in a YAML file\n\n\nThe configuration specifies Policies. Each Policy is tied to a specific algorithm which controls some aspect of the training.\n\n\nSome types of algorithms control the actual sparsification of the model. Such types are \"pruner\" and \"regularizer\".\n\n\nSome algorithms control some parameter of the training process, such as the learning-rate decay scheduler (\nlr_scheduler\n).\n\n\nThe parameters of each algorithm are also specified in the configuration.\n\n\n\n\n\n\n\n\n\n\nIn addition to specifying the algorithm, each Policy specifies scheduling parameters which control when the algorithm is executed: start epoch, end epoch and frequency.\n\n\nThe Scheduler exposes callbacks for relevant training stages: epoch start/end, mini-batch start/end and pre-backward pass. Each scheduler callback activates the policies that were defined according the schedule that was defined.\n\n\nThese callbacks are placed the training loop.\n\n\n\n\nQuantization\n\n\nA quantized model is obtained by replacing existing operations with quantized versions. The quantized versions can be either complete replacements, or wrappers. A wrapper will use the existing modules internally and add quantization and de-quantization operations before/after as necessary.\n\n\nIn Distiller we will provide a set of quantized versions of common operations which will enable implementation of different quantization methods. 
The user can write a quantized model from scratch, using the quantized operations provided.\n\n\nWe also provide a mechanism which takes an existing model and automatically replaces required operations with quantized versions. This mechanism is exposed by the \nQuantizer\n class. \nQuantizer\n should be sub-classed for each quantization method.\n\n\nModel Transformation\n\n\nThe high-level flow is as follows:\n\n\n\n\nDefine a \nmapping\n between the module types to be replaced (e.g. Conv2D, Linear, etc.) to a function which generates the replacement module. The mapping is defined in the \nreplacement_factory\n attribute of the \nQuantizer\n class.\n\n\nIterate over the modules defined in the model. For each module, if its type is in the mapping, call the replacement generation function. We pass the existing module to this function to allow wrapping of it.\n\n\nReplace the existing module with the module returned by the function. It is important to note that the \nname\n of the module \ndoes not\n change, as that could break the \nforward\n function of the parent module.\n\n\n\n\nDifferent quantization methods may, obviously, use different quantized operations. In addition, different methods may employ different \"strategies\" of replacing / wrapping existing modules. For instance, some methods replace ReLU with another activation function, while others keep it. Hence, for each quantization method, a different \nmapping\n will likely be defined.\n\nEach sub-class of \nQuantizer\n should populate the \nreplacement_factory\n dictionary attribute with the appropriate mapping.\n\nTo execute the model transformation, call the \nprepare_model\n function of the \nQuantizer\n instance.\n\n\nFlexible Bit-Widths\n\n\n\n\nEach instance of \nQuantizer\n is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the \nbits_activations\n and \nbits_weights\n parameters in \nQuantizer\n's constructor. Sub-classes may define bit-widths for other tensor types as needed.\n\n\nWe also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern.\n\n\nSo, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the \nbits_overrides\n parameter in the constructor.\n\n\nThe \nbits_overrides\n mapping is required to be an instance of \ncollections.OrderedDict\n (as opposed to just a simple Python \ndict\n). This is done in order to enable handling of overlapping name patterns.\n\n     So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'.\n\n     The patterns are evaluated eagerly - the first match wins. 
Therefore, the more specific patterns must come before the broad patterns.\n\n\n\n\nWeights Quantization\n\n\nThe \nQuantizer\n class also provides an API to quantize the weights of all layers at once. To use it, the \nparam_quantization_fn\n attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the \nQuantizer\n class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the \nquantize_params\n function can be called, which will iterate over all parameters and quantize them using \nparams_quantization_fn\n.\n\n\nTraining with Quantization\n\n\nThe \nQuantizer\n class supports training with quantization in the loop, as described \nhere\n. This is enabled by setting \ntrain_with_fp_copy=True\n in the \nQuantizer\n constructor. At model transformation, in each module that has parameters that should be quantized, a new \ntorch.nn.Parameter\n is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module \nis not\n created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\":\n\n\n\n\nThe existing \ntorch.nn.Parameter\n, e.g. \nweights\n, is replaced by a \ntorch.nn.Parameter\n named \nfloat_weight\n.\n\n\nTo maintain the existing functionality of the module, we then register a \nbuffer\n in the module with the original name - \nweights\n.\n\n\nDuring training, \nfloat_weight\n will be passed to \nparam_quantization_fn\n and the result will be stored in \nweight\n.\n\n\n\n\nImportant Note\n: Since this process modifies the model's parameters, it must be done \nbefore\n an PyTorch \nOptimizer\n is created (this refers to any of the sub-classes defined \nhere\n).\n\n\nThe base \nQuantizer\n class is implemented in \ndistiller/quantization/quantizer.py\n.\n\nFor a simple sub-class implementing symmetric linear quantization, see \nSymmetricLinearQuantizer\n in \ndistiller/quantization/range_linear.py\n. For examples of lower-precision methods using training with quantization see \nDorefaQuantizer\n and \nWRPNQuantizer\n in \ndistiller/quantization/clipped_linear.py", 
+            "text": "Distiller design\n\n\nDistiller is designed to be easily integrated into your own PyTorch research applications.\n\nIt is easiest to understand this integration by examining the code of the sample application for compressing image classification models (\ncompress_classifier.py\n).\n\n\nThe application borrows its main flow code from torchvision's ImageNet classification training sample application (https://github.com/pytorch/examples/tree/master/imagenet). We tried to keep it similar, in order to make it familiar and easy to understand.\n\n\nIntegrating compression is very simple: simply add invocations of the appropriate compression_scheduler callbacks, for each stage in the training.  The training skeleton looks like the pseudo code below.  The boiler-plate Pytorch classification training is speckled with invocations of CompressionScheduler.\n\n\nFor each epoch:\n    compression_scheduler.on_epoch_begin(epoch)\n    train()\n    validate()\n    save_checkpoint()\n    compression_scheduler.on_epoch_end(epoch)\n\ntrain():\n    For each training step:\n        compression_scheduler.on_minibatch_begin(epoch)\n        output = model(input_var)\n        loss = criterion(output, target_var)\n        compression_scheduler.before_backward_pass(epoch)\n        loss.backward()\n        optimizer.step()\n        compression_scheduler.on_minibatch_end(epoch)\n\n\n\n\nThese callbacks can be seen in the diagram below, as the arrow pointing from the Training Loop and into Distiller's \nScheduler\n, which invokes the correct algorithm.  The application also uses Distiller services to collect statistics in \nSummaries\n and logs files, which can be queried at a later time, from Jupyter notebooks or TensorBoard.\n\n\n\n\nSparsification and fine-tuning\n\n\n\n\nThe application sets up a model as normally done in PyTorch.\n\n\nAnd then instantiates a Scheduler and configures it:\n\n\nScheduler configuration is defined in a YAML file\n\n\nThe configuration specifies Policies. Each Policy is tied to a specific algorithm which controls some aspect of the training.\n\n\nSome types of algorithms control the actual sparsification of the model. Such types are \"pruner\" and \"regularizer\".\n\n\nSome algorithms control some parameter of the training process, such as the learning-rate decay scheduler (\nlr_scheduler\n).\n\n\nThe parameters of each algorithm are also specified in the configuration.\n\n\n\n\n\n\n\n\n\n\nIn addition to specifying the algorithm, each Policy specifies scheduling parameters which control when the algorithm is executed: start epoch, end epoch and frequency.\n\n\nThe Scheduler exposes callbacks for relevant training stages: epoch start/end, mini-batch start/end and pre-backward pass. Each scheduler callback activates the policies that were defined according the schedule that was defined.\n\n\nThese callbacks are placed the training loop.\n\n\n\n\nQuantization\n\n\nA quantized model is obtained by replacing existing operations with quantized versions. The quantized versions can be either complete replacements, or wrappers. A wrapper will use the existing modules internally and add quantization and de-quantization operations before/after as necessary.\n\n\nIn Distiller we will provide a set of quantized versions of common operations which will enable implementation of different quantization methods. 
The user can write a quantized model from scratch, using the quantized operations provided.\n\n\nWe also provide a mechanism which takes an existing model and automatically replaces required operations with quantized versions. This mechanism is exposed by the \nQuantizer\n class. \nQuantizer\n should be sub-classed for each quantization method.\n\n\nModel Transformation\n\n\nThe high-level flow is as follows:\n\n\n\n\nDefine a \nmapping\n between the module types to be replaced (e.g. Conv2D, Linear, etc.) to a function which generates the replacement module. The mapping is defined in the \nreplacement_factory\n attribute of the \nQuantizer\n class.\n\n\nIterate over the modules defined in the model. For each module, if its type is in the mapping, call the replacement generation function. We pass the existing module to this function to allow wrapping of it.\n\n\nReplace the existing module with the module returned by the function. It is important to note that the \nname\n of the module \ndoes not\n change, as that could break the \nforward\n function of the parent module.\n\n\n\n\nDifferent quantization methods may, obviously, use different quantized operations. In addition, different methods may employ different \"strategies\" of replacing / wrapping existing modules. For instance, some methods replace ReLU with another activation function, while others keep it. Hence, for each quantization method, a different \nmapping\n will likely be defined.\n\nEach sub-class of \nQuantizer\n should populate the \nreplacement_factory\n dictionary attribute with the appropriate mapping.\n\nTo execute the model transformation, call the \nprepare_model\n function of the \nQuantizer\n instance.\n\n\nFlexible Bit-Widths\n\n\n\n\nEach instance of \nQuantizer\n is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the \nbits_activations\n and \nbits_weights\n parameters in \nQuantizer\n's constructor. Sub-classes may define bit-widths for other tensor types as needed.\n\n\nWe also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern.\n\n\nSo, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the \nbits_overrides\n parameter in the constructor.\n\n\nThe \nbits_overrides\n mapping is required to be an instance of \ncollections.OrderedDict\n (as opposed to just a simple Python \ndict\n). This is done in order to enable handling of overlapping name patterns.\n\n     So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'.\n\n     The patterns are evaluated eagerly - the first match wins. 
Therefore, the more specific patterns must come before the broad patterns.\n\n\n\n\nWeights Quantization\n\n\nThe \nQuantizer\n class also provides an API to quantize the weights of all layers at once. To use it, the \nparam_quantization_fn\n attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the \nQuantizer\n class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the \nquantize_params\n function can be called, which will iterate over all parameters and quantize them using \nparams_quantization_fn\n.\n\n\nTraining with Quantization\n\n\nThe \nQuantizer\n class supports training with quantization in the loop. This requires handling of a couple of flows / scenarios:\n\n\n\n\n\n\nMaintaining a full precision copy of the weights, as described \nhere\n. This is enabled by setting \ntrain_with_fp_copy=True\n in the \nQuantizer\n constructor. At model transformation, in each module that has parameters that should be quantized, a new \ntorch.nn.Parameter\n is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module \nis not\n created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\": \n\n\n\n\nThe existing \ntorch.nn.Parameter\n, e.g. \nweights\n, is replaced by a \ntorch.nn.Parameter\n named \nfloat_weight\n.\n\n\nTo maintain the existing functionality of the module, we then register a \nbuffer\n in the module with the original name - \nweights\n.\n\n\nDuring training, \nfloat_weight\n will be passed to \nparam_quantization_fn\n and the result will be stored in \nweight\n.\n\n\n\n\n\n\n\n\nIn addition, some quantization methods may introduce additional learned parameters to the model. For example, in the \nPACT\n method, acitvations are clipped to a value \n\\alpha\n, which is a learned parameter per-layer\n\n\n\n\n\n\nTo support these two cases, the \nQuantizer\n class also accepts an instance of a \ntorch.optim.Optimizer\n (normally this would be one an instance of its sub-classes). The quantizer will take care of modifying the optimizer according to the changes made to the parameters.   \n\n\n\n\nOptimizing New Parameters\n\n\nIn cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the speicifc method should override \nQuantizer._get_updated_optimizer_params_groups()\n, and return the proper groups plus any desired hyper-parameter overrides.\n\n\n\n\nExamples\n\n\nThe base \nQuantizer\n class is implemented in \ndistiller/quantization/quantizer.py\n.\n\nFor a simple sub-class implementing symmetric linear quantization, see \nSymmetricLinearQuantizer\n in \ndistiller/quantization/range_linear.py\n.\n\nIn \ndistiller/quantization/clipped_linear.py\n there are examples of lower-precision methods which use training with quantization. Specifically, see \nPACTQuantizer\n for an example of overriding \nQuantizer._get_updated_optimizer_params_groups()\n.", 
             "title": "Design"
         }, 
         {
@@ -562,8 +567,13 @@
         }, 
         {
             "location": "/design/index.html#training-with-quantization", 
-            "text": "The  Quantizer  class supports training with quantization in the loop, as described  here . This is enabled by setting  train_with_fp_copy=True  in the  Quantizer  constructor. At model transformation, in each module that has parameters that should be quantized, a new  torch.nn.Parameter  is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module  is not  created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\":   The existing  torch.nn.Parameter , e.g.  weights , is replaced by a  torch.nn.Parameter  named  float_weight .  To maintain the existing functionality of the module, we then register a  buffer  in the module with the original name -  weights .  During training,  float_weight  will be passed to  param_quantization_fn  and the result will be stored in  weight .   Important Note : Since this process modifies the model's parameters, it must be done  before  an PyTorch  Optimizer  is created (this refers to any of the sub-classes defined  here ).  The base  Quantizer  class is implemented in  distiller/quantization/quantizer.py . \nFor a simple sub-class implementing symmetric linear quantization, see  SymmetricLinearQuantizer  in  distiller/quantization/range_linear.py . For examples of lower-precision methods using training with quantization see  DorefaQuantizer  and  WRPNQuantizer  in  distiller/quantization/clipped_linear.py", 
+            "text": "The  Quantizer  class supports training with quantization in the loop. This requires handling of a couple of flows / scenarios:    Maintaining a full precision copy of the weights, as described  here . This is enabled by setting  train_with_fp_copy=True  in the  Quantizer  constructor. At model transformation, in each module that has parameters that should be quantized, a new  torch.nn.Parameter  is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module  is not  created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\":    The existing  torch.nn.Parameter , e.g.  weights , is replaced by a  torch.nn.Parameter  named  float_weight .  To maintain the existing functionality of the module, we then register a  buffer  in the module with the original name -  weights .  During training,  float_weight  will be passed to  param_quantization_fn  and the result will be stored in  weight .     In addition, some quantization methods may introduce additional learned parameters to the model. For example, in the  PACT  method, acitvations are clipped to a value  \\alpha , which is a learned parameter per-layer    To support these two cases, the  Quantizer  class also accepts an instance of a  torch.optim.Optimizer  (normally this would be one an instance of its sub-classes). The quantizer will take care of modifying the optimizer according to the changes made to the parameters.      Optimizing New Parameters  In cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the speicifc method should override  Quantizer._get_updated_optimizer_params_groups() , and return the proper groups plus any desired hyper-parameter overrides.", 
             "title": "Training with Quantization"
+        }, 
+        {
+            "location": "/design/index.html#examples", 
+            "text": "The base  Quantizer  class is implemented in  distiller/quantization/quantizer.py . \nFor a simple sub-class implementing symmetric linear quantization, see  SymmetricLinearQuantizer  in  distiller/quantization/range_linear.py . \nIn  distiller/quantization/clipped_linear.py  there are examples of lower-precision methods which use training with quantization. Specifically, see  PACTQuantizer  for an example of overriding  Quantizer._get_updated_optimizer_params_groups() .", 
+            "title": "Examples"
         }
     ]
 }
\ No newline at end of file
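For readers following the Design text captured in the search index above, here is a minimal, hypothetical sketch of the sub-classing flow it describes: populate replacement_factory, order overlapping bits_overrides patterns from specific to broad, and call prepare_model(). Only the Quantizer constructor argument order, the replacement_factory lambda shape, and prepare_model() are taken from this patch; DummyWrapper is an illustrative placeholder, and the wts/acts override value format is an assumption borrowed from the YAML examples further down.

# Sketch only - not part of this patch.
from collections import OrderedDict

import torch
import torch.nn as nn

from distiller.quantization import Quantizer


class DummyWrapper(nn.Module):
    """Illustrative placeholder for a real quantized wrapper layer."""
    def __init__(self, module, qbits):
        super(DummyWrapper, self).__init__()
        self.wrapped_module = module
        self.qbits = qbits

    def forward(self, x):
        # A real wrapper would add quantize / de-quantize ops around this call
        return self.wrapped_module(x)


class MyQuantizer(Quantizer):
    def __init__(self, model, optimizer=None, bits_activations=8, bits_weights=8,
                 bits_overrides=OrderedDict(), quantize_bias=False, train_with_fp_copy=False):
        super(MyQuantizer, self).__init__(model, optimizer, bits_activations, bits_weights,
                                          bits_overrides, quantize_bias, train_with_fp_copy)
        # Map module types to replacement-generating functions; the existing module is
        # passed in so it can be wrapped, and its name is preserved.
        self.replacement_factory[nn.Conv2d] = \
            lambda module, name, qbits_map: DummyWrapper(module, qbits_map[name])


model = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(3, 8, 3)),
                                   ('conv2', nn.Conv2d(8, 8, 3))]))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Overlapping override patterns: the exact name 'conv1' must precede the broad 'conv.*'
# pattern, because patterns are matched eagerly and the first match wins.
# (Override value format is an assumption; see the YAML examples and tests/test_quantizer.py.)
bits_overrides = OrderedDict([('conv1', {'wts': None, 'acts': None}),
                              ('conv.*', {'wts': 3, 'acts': 4})])

quantizer = MyQuantizer(model, optimizer=optimizer, bits_activations=4, bits_weights=3,
                        bits_overrides=bits_overrides, train_with_fp_copy=True)
quantizer.prepare_model()  # replaces mapped modules in-place, without renaming them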
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 8ca15b0fb16a366481fcb41abf4ae7008c682c59..84165849ca515ac50dfdeb5582f468852bbf2e03 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -4,7 +4,7 @@
     
     <url>
      <loc>/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     
@@ -12,7 +12,7 @@
     
     <url>
      <loc>/install/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     
@@ -20,7 +20,7 @@
     
     <url>
      <loc>/usage/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     
@@ -28,7 +28,7 @@
     
     <url>
      <loc>/schedule/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     
@@ -37,19 +37,19 @@
         
     <url>
      <loc>/pruning/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
         
     <url>
      <loc>/regularization/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
         
     <url>
      <loc>/quantization/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
         
@@ -59,13 +59,13 @@
         
     <url>
      <loc>/algo_pruning/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
         
     <url>
      <loc>/algo_quantization/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
         
@@ -74,7 +74,7 @@
     
     <url>
      <loc>/model_zoo/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     
@@ -82,7 +82,7 @@
     
     <url>
      <loc>/jupyter/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     
@@ -90,7 +90,7 @@
     
     <url>
      <loc>/design/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index fe45603b164c8992532b5a5e476f3afa35e947c2..36cee2edcee1c0734712810d90970ed28e1c4d85 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -312,8 +312,11 @@ def main():
         # The main use-case for this sample application is CNN compression. Compression
         # requires a compression schedule configuration file in YAML.
         compression_scheduler = distiller.file_config(model, optimizer, args.compress)
+        # Re-transfer the model to the GPU in case the quantizer added new parameters (e.g. PACTQuantizer's learned clip values)
+        model.cuda()
 
     best_epoch = start_epoch
+
     for epoch in range(start_epoch, start_epoch + args.epochs):
         # This is the main training loop.
         msglogger.info('\n')
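A condensed sketch of how the pieces above fit together in the sample application, following the callback skeleton from the Design page. The callback signatures are simplified exactly as in that pseudo-code, and train_loader / criterion are stand-ins, not names from this patch.

# Sketch only - not part of this patch.
compression_scheduler = distiller.file_config(model, optimizer, args.compress)
model.cuda()  # re-transfer, in case the quantizer registered new parameters

for epoch in range(start_epoch, start_epoch + args.epochs):
    compression_scheduler.on_epoch_begin(epoch)
    for inputs, target in train_loader:          # train_loader: stand-in DataLoader
        compression_scheduler.on_minibatch_begin(epoch)
        output = model(inputs.cuda())
        loss = criterion(output, target.cuda())  # criterion: stand-in loss function
        compression_scheduler.before_backward_pass(epoch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        compression_scheduler.on_minibatch_end(epoch)
    # ... validate / save_checkpoint ...
    compression_scheduler.on_epoch_end(epoch)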
diff --git a/examples/quantization/preact_resnet20_cifar_base_fp32.yaml b/examples/quantization/preact_resnet20_cifar_base_fp32.yaml
index 792d971ae469b83d2de306c6c510ec9404d8c6eb..63aac98002be1b280e491e65faa12ec7454adcd5 100644
--- a/examples/quantization/preact_resnet20_cifar_base_fp32.yaml
+++ b/examples/quantization/preact_resnet20_cifar_base_fp32.yaml
@@ -1,3 +1,21 @@
+
+# time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1
+# --epochs 200 --compress=../quantization/preact_resnet20_cifar_base_fp32.yaml --out-dir="logs/" --wd=0.0002 --vs=0
+
+
+#2018-07-18 12:25:56,477 - --- validate (epoch=199)-----------
+#2018-07-18 12:25:56,477 - 10000 samples (128 per mini-batch)
+#2018-07-18 12:25:57,810 - Epoch: [199][   50/   78]    Loss 0.312961    Top1 92.140625    Top5 99.765625
+#2018-07-18 12:25:58,402 - ==> Top1: 92.270    Top5: 99.800    Loss: 0.307
+#
+#2018-07-18 12:25:58,404 - ==> Best validation Top1: 92.560   Epoch: 127
+#2018-07-18 12:25:58,404 - Saving checkpoint to: logs/checkpoint.pth.tar
+#2018-07-18 12:25:58,418 - --- test ---------------------
+#2018-07-18 12:25:58,418 - 10000 samples (128 per mini-batch)
+#2018-07-18 12:25:59,664 - Test: [   50/   78]    Loss 0.312961    Top1 92.140625    Top5 99.765625
+#2018-07-18 12:26:00,248 - ==> Top1: 92.270    Top5: 99.800    Loss: 0.307
+
+
 lr_schedulers:
   training_lr:
     class: MultiStepMultiGammaLR
diff --git a/examples/quantization/preact_resnet20_cifar_pact.yaml b/examples/quantization/preact_resnet20_cifar_pact.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8ef1858df708570a02e5b124e04b7c94c5b95f38
--- /dev/null
+++ b/examples/quantization/preact_resnet20_cifar_pact.yaml
@@ -0,0 +1,57 @@
+
+# time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1
+# --epochs 200 --compress=../quantization/preact_resnet20_cifar_pact.yaml --out-dir="logs/" --wd=0.0002 --vs=0
+
+
+#2018-07-18 17:28:56,710 - --- validate (epoch=199)-----------
+#2018-07-18 17:28:56,710 - 10000 samples (128 per mini-batch)
+#2018-07-18 17:28:58,070 - Epoch: [199][   50/   78]    Loss 0.349229    Top1 91.140625    Top5 99.671875
+#2018-07-18 17:28:58,670 - ==> Top1: 91.440    Top5: 99.680    Loss: 0.348
+#
+#2018-07-18 17:28:58,671 - ==> Best validation Top1: 91.860   Epoch: 147
+#2018-07-18 17:28:58,672 - Saving checkpoint to: logs/checkpoint.pth.tar
+#2018-07-18 17:28:58,687 - --- test ---------------------
+#2018-07-18 17:28:58,687 - 10000 samples (128 per mini-batch)
+#2018-07-18 17:29:00,006 - Test: [   50/   78]    Loss 0.349229    Top1 91.140625    Top5 99.671875
+#2018-07-18 17:29:00,560 - ==> Top1: 91.440    Top5: 99.680    Loss: 0.348
+
+
+quantizers:
+  pact_quantizer:
+    class: PACTQuantizer
+    act_clip_init_val: 8.0
+    bits_activations: 4
+    bits_weights: 3
+    bits_overrides:
+    # Don't quantize first and last layers
+      conv1:
+        wts: null
+        acts: null
+      layer1.0.pre_relu:
+        wts: null
+        acts: null
+      final_relu:
+        wts: null
+        acts: null
+      fc:
+        wts: null
+        acts: null
+
+lr_schedulers:
+  training_lr:
+    class: MultiStepLR
+    milestones: [60, 120]
+    gamma: 0.1
+
+policies:
+    - quantizer:
+        instance_name: pact_quantizer
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
+
+    - lr_scheduler:
+        instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 121
+      frequency: 1
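The PACT schedule above introduces new learned clipping parameters, which the Design page says should be handled by overriding Quantizer._get_updated_optimizer_params_groups(). Below is a hypothetical sketch of such an override; the no-argument signature, the returned list-of-param-group-dicts convention, the self.model attribute, and the 'clip_val' parameter-name filter are all assumptions made for illustration - PACTQuantizer in distiller/quantization/clipped_linear.py is the authoritative example.

# Hypothetical sketch only (see note above): give newly added clipping parameters their own
# optimizer param group so they can receive different hyper-parameters (e.g. no weight decay).
from distiller.quantization import Quantizer


class MyPactLikeQuantizer(Quantizer):
    def _get_updated_optimizer_params_groups(self):
        base_params = [p for n, p in self.model.named_parameters() if 'clip_val' not in n]
        clip_params = [p for n, p in self.model.named_parameters() if 'clip_val' in n]
        return [{'params': base_params},
                {'params': clip_params, 'weight_decay': 0.0}]  # example hyper-parameter override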
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index e8ea1272c1a81b21028da3a7048fdfb2718ba0f1..0409335d2642d77d501919aad049a22ad6ac7b18 100644
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -106,8 +106,8 @@ test_configs = [
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
                DS_CIFAR, accuracy_checker, [91.620, 99.630]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
-               format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_dorefa_test.yaml')),
-               DS_CIFAR, accuracy_checker, [41.190, 91.160])
+               format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
+               DS_CIFAR, accuracy_checker, [48.290, 94.460])
 ]
 
 
diff --git a/tests/full_flow_tests/preact_resnet20_cifar_dorefa_test.yaml b/tests/full_flow_tests/preact_resnet20_cifar_pact_test.yaml
similarity index 64%
rename from tests/full_flow_tests/preact_resnet20_cifar_dorefa_test.yaml
rename to tests/full_flow_tests/preact_resnet20_cifar_pact_test.yaml
index a9870530bc6be26214b85610b417ea89ebd3a324..1a83aa435303dfda84e6215357af5763693ceba8 100644
--- a/tests/full_flow_tests/preact_resnet20_cifar_dorefa_test.yaml
+++ b/tests/full_flow_tests/preact_resnet20_cifar_pact_test.yaml
@@ -1,10 +1,11 @@
 quantizers:
-  dorefa_quantizer:
-    class: DorefaQuantizer
-    bits_activations: 8
+  pact_quantizer:
+    class: PACTQuantizer
+    act_clip_init_val: 8.0
+    bits_activations: 4
     bits_weights: 3
     bits_overrides:
-    # Don't quantize first and last layer
+    # Don't quantize first and last layers
       conv1:
         wts: null
         acts: null
@@ -20,13 +21,13 @@ quantizers:
 
 lr_schedulers:
   training_lr:
-    class: MultiStepMultiGammaLR
-    milestones: [80, 120, 160]
-    gammas: [0.1, 0.1, 0.2]
+    class: MultiStepLR
+    milestones: [60, 120]
+    gamma: 0.1
 
 policies:
     - quantizer:
-        instance_name: dorefa_quantizer
+        instance_name: pact_quantizer
       starting_epoch: 0
       ending_epoch: 200
       frequency: 1
@@ -34,5 +35,5 @@ policies:
     - lr_scheduler:
         instance_name: training_lr
       starting_epoch: 0
-      ending_epoch: 161
+      ending_epoch: 121
       frequency: 1
diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py
index fece996b03bca6b1a5c5dc8fb4b4260a036e344a..811e46c912283fb61deae0bad76a0cc60e4a6a96 100644
--- a/tests/test_quantizer.py
+++ b/tests/test_quantizer.py
@@ -95,9 +95,9 @@ def dummy_quantize_params(param, num_bits):
 
 
 class DummyQuantizer(Quantizer):
-    def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(),
+    def __init__(self, model, optimizer=None, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(),
                  quantize_bias=False, train_with_fp_copy=False):
-        super(DummyQuantizer, self).__init__(model, bits_activations, bits_weights, bits_overrides, quantize_bias,
+        super(DummyQuantizer, self).__init__(model, optimizer, bits_activations, bits_weights, bits_overrides, quantize_bias,
                                              train_with_fp_copy)
 
         self.replacement_factory[nn.Conv2d] = lambda module, name, qbits_map: DummyWrapperLayer(module, qbits_map[name])
@@ -143,6 +143,12 @@ def fixture_model():
     return DummyModel()
 
 
+# TODO: Test optimizer modifications in 'test_model_prep'
+@pytest.fixture(name='optimizer')
+def fixture_optimizer(model):
+    return torch.optim.SGD(model.parameters(), lr=0.1)
+
+
 @pytest.fixture(name='train_with_fp_copy', params=[False, True], ids=['fp_copy_off', 'fp_copy_on'])
 def fixture_train_with_fp_copy(request):
     return request.param
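One possible way to address the TODO above, sketched here and not part of this patch: after prepare_model(), every tensor the optimizer steps on should still be a live parameter of the (possibly transformed) model, per the documented contract that the quantizer keeps the optimizer in sync with parameter changes. The fixture names reuse the ones defined in this file.

# Sketch for the TODO above (not part of this patch).
def test_optimizer_updated_after_prep(model, optimizer, train_with_fp_copy):
    q = DummyQuantizer(model, optimizer=optimizer, bits_activations=8, bits_weights=8,
                       train_with_fp_copy=train_with_fp_copy)
    q.prepare_model()

    model_param_ids = {id(p) for p in model.parameters()}
    optimizer_param_ids = {id(p) for group in optimizer.param_groups for p in group['params']}
    # With train_with_fp_copy=True the original 'weight' parameters become buffers and are
    # replaced by 'float_weight' parameters; the optimizer should reference the new ones.
    assert optimizer_param_ids <= model_param_ids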
@@ -213,7 +219,7 @@ def test_overrides_ordered_dict(model):
                                             #             never actually matched
     ]
 )
-def test_model_prep(model, qbits, bits_overrides, explicit_expected_overrides,
+def test_model_prep(model, optimizer, qbits, bits_overrides, explicit_expected_overrides,
                     train_with_fp_copy, quantize_bias, parallel):
     if parallel:
         model = torch.nn.DataParallel(model)
@@ -223,7 +229,7 @@ def test_model_prep(model, qbits, bits_overrides, explicit_expected_overrides,
     expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides)
 
     # Initialize Quantizer
-    q = DummyQuantizer(model, bits_activations=qbits.acts, bits_weights=qbits.wts,
+    q = DummyQuantizer(model, optimizer=optimizer, bits_activations=qbits.acts, bits_weights=qbits.wts,
                        bits_overrides=deepcopy(bits_overrides), train_with_fp_copy=train_with_fp_copy,
                        quantize_bias=quantize_bias)
 
@@ -290,12 +296,12 @@ def test_model_prep(model, qbits, bits_overrides, explicit_expected_overrides,
           'sub1.conv1': QBits(8, 4), 'sub1.conv2': QBits(4, 4), 'sub2.conv1': QBits(8, 4), 'sub2.conv2': QBits(4, 4)}),
     ]
 )
-def test_param_quantization(model, qbits, bits_overrides, explicit_expected_overrides,
+def test_param_quantization(model, optimizer, qbits, bits_overrides, explicit_expected_overrides,
                             train_with_fp_copy, quantize_bias):
     # Build expected QBits
     expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides)
 
-    q = DummyQuantizer(model, bits_activations=qbits.acts, bits_weights=qbits.wts,
+    q = DummyQuantizer(model, optimizer=optimizer, bits_activations=qbits.acts, bits_weights=qbits.wts,
                        bits_overrides=deepcopy(bits_overrides), train_with_fp_copy=train_with_fp_copy,
                        quantize_bias=quantize_bias)
     q.prepare_model()