diff --git a/.gitignore b/.gitignore
index b820feca28116dca2f6b739ac23ce4f3adc78f29..8f641a310621fb58fd6d9b2cade1474b9f6e6588 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 *.pyc
 __pycache__/
 .pytest_cache
+.cache
 *.tar
 site/
 env/
diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index 83ae242b4c0d8c9281e3d87226a0d005ab79341a..16c9c98b90a0e2b49e9df6d2b9db302dd5849f19 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -28,7 +28,8 @@ import distiller
 msglogger = logging.getLogger()
 
 
-def save_checkpoint(epoch, arch, model, optimizer, scheduler=None, best_top1=None, is_best=False, name=None):
+def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
+                    best_top1=None, is_best=False, name=None, dir='.'):
     """Save a pytorch training checkpoint
 
     Args:
@@ -40,10 +41,16 @@ def save_checkpoint(epoch, arch, model, optimizer, scheduler=None, best_top1=Non
         best_top1: the best top1 score seen so far
         is_best: True if this is the best (top1 accuracy) model so far
         name: the name of the checkpoint file
+        dir: directory in which to save the checkpoint
     """
     msglogger.info("Saving checkpoint")
+    if not os.path.isdir(dir):
+        msglogger.info("Error: Directory to save checkpoint doesn't exist - {0}".format(os.path.abspath(dir)))
+        exit(1)
     filename = 'checkpoint.pth.tar' if name is None else name + '_checkpoint.pth.tar'
+    fullpath = os.path.join(dir, filename)
     filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar'
+    fullpath_best = os.path.join(dir, filename_best)
     checkpoint = {}
     checkpoint['epoch'] = epoch
     checkpoint['arch'] =  arch
@@ -56,10 +63,12 @@ def save_checkpoint(epoch, arch, model, optimizer, scheduler=None, best_top1=Non
         checkpoint['compression_sched'] = scheduler.state_dict()
     if hasattr(model, 'thinning_recipes'):
         checkpoint['thinning_recipes'] = model.thinning_recipes
+    if hasattr(model, 'quantizer_metadata'):
+        checkpoint['quantizer_metadata'] = model.quantizer_metadata
 
-    torch.save(checkpoint, filename)
+    torch.save(checkpoint, fullpath)
     if is_best:
-        shutil.copyfile(filename, filename_best)
+        shutil.copyfile(fullpath, fullpath_best)
 
 
 def load_checkpoint(model, chkpt_file, optimizer=None):
@@ -86,20 +95,26 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
             compression_scheduler.load_state_dict(checkpoint['compression_sched'])
             msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
                            checkpoint['epoch'])
+        else:
+            msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
 
         if 'thinning_recipes' in checkpoint:
             if 'compression_sched' not in checkpoint:
-                raise KeyError("Found thinning_recipes key, but missing mandatoy key compression_sched")
+                raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
             msglogger.info("Loaded a thinning recipe from the checkpoint")
             # Cache the recipes in case we need them later
             model.thinning_recipes = checkpoint['thinning_recipes']
             distiller.execute_thinning_recipes_list(model,
                                               compression_scheduler.zeros_mask_dict,
                                               model.thinning_recipes)
-        else:
-            msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
-            msglogger.info("=> loaded checkpoint '%s' (epoch %d)",
-                           chkpt_file, checkpoint['epoch'])
+
+        if 'quantizer_metadata' in checkpoint:
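+            # Re-create the quantizer that was active when this checkpoint was saved and apply it to
+            # the model, so that the checkpoint's state_dict (which reflects the quantized model
+            # structure) can be loaded below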
+            msglogger.info('Loaded quantizer metadata from the checkpoint')
+            qmd = checkpoint['quantizer_metadata']
+            quantizer = qmd['type'](model, **qmd['params'])
+            quantizer.prepare_model()
+
+        msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, checkpoint['epoch'])
 
         model.load_state_dict(checkpoint['state_dict'])
         return model, compression_scheduler, start_epoch
diff --git a/apputils/data_loaders.py b/apputils/data_loaders.py
index 41f7bd867dc361a5c76fc9dad1cd2e5be1044c6f..0771bef0f693b59503b057c9eeeeaa2325283631 100755
--- a/apputils/data_loaders.py
+++ b/apputils/data_loaders.py
@@ -29,7 +29,7 @@ import numpy as np
 DATASETS_NAMES = ['imagenet', 'cifar10']
 
 
-def load_data(dataset, data_dir, batch_size, workers, deterministic=False):
+def load_data(dataset, data_dir, batch_size, workers, valid_size=0.1, deterministic=False):
     """Load a dataset.
 
     Args:
@@ -37,14 +37,15 @@ def load_data(dataset, data_dir, batch_size, workers, deterministic=False):
         data_dir: the directory where the datset resides
         batch_size: the batch size
         workers: the number of worker threads to use for loading the data
+        valid_size: portion of training dataset to set aside for validation
         deterministic: set to True if you want the data loading process to be deterministic.
           Note that deterministic data loading suffers from poor performance.
     """
     assert dataset in DATASETS_NAMES
     if dataset == 'cifar10':
-        return cifar10_load_data(data_dir, batch_size, workers, deterministic=deterministic)
+        return cifar10_load_data(data_dir, batch_size, workers, valid_size=valid_size, deterministic=deterministic)
     if dataset == 'imagenet':
-        return imagenet_load_data(data_dir, batch_size, workers, deterministic=deterministic)
+        return imagenet_load_data(data_dir, batch_size, workers, valid_size=valid_size, deterministic=deterministic)
     print("FATAL ERROR: load_data does not support dataset %s" % dataset)
     exit(1)
 
@@ -73,7 +74,7 @@ def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determi
     We transform them to Tensors of normalized range [-1, 1]
     https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py
 
-    Data augmentation: 4 pixels are padded on each side, and a 32x32 crop is randomly sampled
+    Data augmentation: 4 pixels are padded on each side, and a 32x32 crop is randomly sampled
     from the padded image or its horizontal flip.
     This is similar to [1] and some other work that use CIFAR10.
 
@@ -103,7 +104,6 @@ def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determi
 
     train_idx, valid_idx = indices[split:], indices[:split]
     train_sampler = SubsetRandomSampler(train_idx)
-    valid_sampler = SubsetRandomSampler(valid_idx)
 
     worker_init_fn = __deterministic_worker_init_fn if deterministic else None
 
@@ -112,10 +112,13 @@ def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determi
                                                num_workers=num_workers, pin_memory=True,
                                                worker_init_fn=worker_init_fn)
 
-    valid_loader = torch.utils.data.DataLoader(train_dataset,
-                                               batch_size=batch_size, sampler=valid_sampler,
-                                               num_workers=num_workers, pin_memory=True,
-                                               worker_init_fn=worker_init_fn)
+    valid_loader = None
+    if split > 0:
+        valid_sampler = SubsetRandomSampler(valid_idx)
+        valid_loader = torch.utils.data.DataLoader(train_dataset,
+                                                   batch_size=batch_size, sampler=valid_sampler,
+                                                   num_workers=num_workers, pin_memory=True,
+                                                   worker_init_fn=worker_init_fn)
 
     testset = datasets.CIFAR10(root=data_dir, train=False,
                                download=True, transform=transform_test)
@@ -125,7 +128,9 @@ def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determi
             num_workers=num_workers, pin_memory=True)
 
     input_shape = __image_size(train_dataset)
-    return train_loader, valid_loader, test_loader, input_shape
+
+    # If the validation split was 0, use the test set as the validation set
+    return train_loader, valid_loader or test_loader, test_loader, input_shape
 
 
 def imagenet_load_data(data_dir, batch_size, num_workers, valid_size=0.1, deterministic=False):
@@ -159,7 +164,6 @@ def imagenet_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determ
 
     train_idx, valid_idx = indices[split:], indices[:split]
     train_sampler = SubsetRandomSampler(train_idx)
-    valid_sampler = SubsetRandomSampler(valid_idx)
 
     input_shape = __image_size(train_dataset)
 
@@ -170,10 +174,13 @@ def imagenet_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determ
                                                num_workers=num_workers, pin_memory=True,
                                                worker_init_fn=worker_init_fn)
 
-    valid_loader = torch.utils.data.DataLoader(train_dataset,
-                                               batch_size=batch_size, sampler=valid_sampler,
-                                               num_workers=num_workers, pin_memory=True,
-                                               worker_init_fn=worker_init_fn)
+    valid_loader = None
+    if split > 0:
+        valid_sampler = SubsetRandomSampler(valid_idx)
+        valid_loader = torch.utils.data.DataLoader(train_dataset,
+                                                   batch_size=batch_size, sampler=valid_sampler,
+                                                   num_workers=num_workers, pin_memory=True,
+                                                   worker_init_fn=worker_init_fn)
 
     test_loader = torch.utils.data.DataLoader(
         datasets.ImageFolder(test_dir, transforms.Compose([
@@ -185,4 +192,5 @@ def imagenet_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determ
         batch_size=batch_size, shuffle=False,
         num_workers=num_workers, pin_memory=True)
 
-    return train_loader, valid_loader, test_loader, input_shape
+    # If the validation split was 0, use the test set as the validation set
+    return train_loader, valid_loader or test_loader, test_loader, input_shape
diff --git a/apputils/execution_env.py b/apputils/execution_env.py
index 52441828263f903b00135dd448e0ffee36739913..09d079d5bf9fb58177b2cd41d6ca70cbbbee3961 100755
--- a/apputils/execution_env.py
+++ b/apputils/execution_env.py
@@ -58,7 +58,6 @@ def log_execution_env_state(app_args, gitroot='.'):
 
         if repo.is_dirty():
             logger.debug("Git is dirty")
-            #repo.index.diff(None)
         try:
             branch_name = repo.active_branch.name
         except TypeError:
@@ -80,7 +79,7 @@ def log_execution_env_state(app_args, gitroot='.'):
     logger.debug("App args: %s", app_args)
 
 
-def config_pylogger(log_cfg_file, experiment_name):
+def config_pylogger(log_cfg_file, experiment_name, output_dir='logs'):
     """Configure the Python logger.
 
     For each execution of the application, we'd like to create a unique log directory.
@@ -90,11 +89,11 @@ def config_pylogger(log_cfg_file, experiment_name):
     TensorBoard, for example.
     """
     timestr = time.strftime("%Y.%m.%d-%H%M%S")
-    filename = timestr if experiment_name is None else experiment_name + '___' + timestr
-    logdir =  os.path.join('./logs', filename)
+    exp_full_name = timestr if experiment_name is None else experiment_name + '___' + timestr
+    logdir = os.path.join(output_dir, exp_full_name)
     if not os.path.exists(logdir):
         os.makedirs(logdir)
-    log_filename = os.path.join(logdir, filename + '.log')
+    log_filename = os.path.join(logdir, exp_full_name + '.log')
     if os.path.isfile(log_cfg_file):
         logging.config.fileConfig(log_cfg_file, defaults={'logfilename': log_filename})
     msglogger = logging.getLogger()
diff --git a/distiller/__init__.py b/distiller/__init__.py
index 448f08f8dd29812dc3d169e1a60790363fcab8f5..a3213b564dea9d592068f5c2ef9f35fb92558e17 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -16,7 +16,7 @@
 
 from .utils import *
 from .thresholding import GroupThresholdMixin, threshold_mask
-from .config import fileConfig, dictConfig
+from .config import file_config, dict_config
 from .model_summaries import *
 from .scheduler import *
 from .sensitivity import *
@@ -25,7 +25,7 @@ from .policy import *
 from .thinning import *
 
 #del utils
-del dictConfig
+del dict_config
 del thinning
 #del model_summaries
 #del scheduler
diff --git a/distiller/config.py b/distiller/config.py
index 21ede1d52ba8f24351327d99616364afe65e4026..a5cc9cd76a82b7c90e873f9d27a7cafc5244742c 100755
--- a/distiller/config.py
+++ b/distiller/config.py
@@ -33,27 +33,37 @@ When a YAML file is loaded, its dictionary is extracted and passed to ```dictCon
 """
 
 import logging
+from collections import OrderedDict
 import yaml
 import json
 import inspect
-import distiller
 from torch.optim.lr_scheduler import *
 import distiller
 from distiller.thinning import *
 from distiller.pruning import *
 from distiller.regularization import L1Regularizer, GroupLassoRegularizer
 from distiller.learning_rate import *
-logger = logging.getLogger("app_cfg")
+from distiller.quantization import *
+
+msglogger = logging.getLogger()
+app_cfg_logger = logging.getLogger("app_cfg")
+
 
-def dictConfig(model, optimizer, schedule, sched_dict, logger):
-    logger.debug(json.dumps(sched_dict, indent=1))
+def dict_config(model, optimizer, sched_dict):
+    app_cfg_logger.debug('Schedule contents:\n' + json.dumps(sched_dict, indent=2))
+
+    schedule = distiller.CompressionScheduler(model)
 
     pruners = __factory('pruners', model, sched_dict)
     regularizers = __factory('regularizers', model, sched_dict)
-    lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer)
+    quantizers = __factory('quantizers', model, sched_dict)
+    if len(quantizers) > 1:
+        print("\nError: Multiple Quantizers not supported")
+        exit(1)
     extensions = __factory('extensions', model, sched_dict)
 
     try:
+        lr_policies = []
         for policy_def in sched_dict['policies']:
             policy = None
             if 'pruner' in policy_def:
@@ -76,11 +86,23 @@ def dictConfig(model, optimizer, schedule, sched_dict, logger):
                 else:
                     policy = distiller.RegularizationPolicy(regularizer, **args)
 
+            elif 'quantizer' in policy_def:
+                instance_name, args = __policy_params(policy_def, 'quantizer')
+                assert instance_name in quantizers, "Quantizer {} was not defined in the list of quantizers".format(instance_name)
+                quantizer = quantizers[instance_name]
+                policy = distiller.QuantizationPolicy(quantizer)
+
+                # Quantizers for training modify the model's parameters, so we need to update the optimizer
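+                # (prepare_model() replaces the original float parameters with quantized buffers plus
+                # renamed float copies, so the optimizer's existing parameter references are stale)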
+                if quantizer.train_with_fp_copy:
+                    optimizer_type = type(optimizer)
+                    new_optimizer = optimizer_type(model.parameters(), **optimizer.defaults)
+                    optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
+
             elif 'lr_scheduler' in policy_def:
-                instance_name, args = __policy_params(policy_def, 'lr_scheduler')
-                assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format(instance_name)
-                lr_scheduler = lr_schedulers[instance_name]
-                policy = distiller.LRPolicy(lr_scheduler)
+                # LR schedulers take an optimizer in their constructor, so postpone handling until we're
+                # certain that the quantization policy, if one exists, has been initialized
+                lr_policies.append(policy_def)
+                continue
 
             elif 'extension' in policy_def:
                 instance_name, args = __policy_params(policy_def, 'extension')
@@ -92,12 +114,18 @@ def dictConfig(model, optimizer, schedule, sched_dict, logger):
                 print("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]" % policy_def)
                 exit(1)
 
-            if 'epochs' in policy_def:
-                schedule.add_policy(policy, epochs=policy_def['epochs'])
-            else:
-                schedule.add_policy(policy, starting_epoch=policy_def['starting_epoch'],
-                                            ending_epoch=policy_def['ending_epoch'],
-                                            frequency=policy_def['frequency'])
+            add_policy_to_scheduler(policy, policy_def, schedule)
+
+        # Any changes to the optimizer caused by a quantizer have occurred by now, so it's safe to create LR schedulers
+        lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer)
+        for policy_def in lr_policies:
+            instance_name, args = __policy_params(policy_def, 'lr_scheduler')
+            assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format(
+                instance_name)
+            lr_scheduler = lr_schedulers[instance_name]
+            policy = distiller.LRPolicy(lr_scheduler)
+            add_policy_to_scheduler(policy, policy_def, schedule)
+
     except AssertionError:
         # propagate the assertion information
         raise
@@ -108,12 +136,23 @@ def dictConfig(model, optimizer, schedule, sched_dict, logger):
 
     return schedule
 
-def fileConfig(model, optimizer, schedule, filename, logger):
+
+def add_policy_to_scheduler(policy, policy_def, schedule):
+    if 'epochs' in policy_def:
+        schedule.add_policy(policy, epochs=policy_def['epochs'])
+    else:
+        schedule.add_policy(policy, starting_epoch=policy_def['starting_epoch'],
+                            ending_epoch=policy_def['ending_epoch'],
+                            frequency=policy_def['frequency'])
+
+
+def file_config(model, optimizer, filename):
     """Read the schedule from file"""
     with open(filename, 'r') as stream:
+        msglogger.info('Reading compression schedule from: %s', filename)
         try:
-            sched_dict = yaml.load(stream)
-            dictConfig(model, optimizer, schedule, sched_dict, logger)
+            sched_dict = yaml_ordered_load(stream)
+            return dict_config(model, optimizer, sched_dict)
         except yaml.YAMLError as exc:
             print("\nFATAL Parsing error while parsing the pruning schedule configuration file %s" % filename)
             exit(1)
@@ -165,3 +204,22 @@ def __policy_params(policy_def, type):
     name = policy_def[type]['instance_name']
     args = policy_def[type].get('args', None)
     return name, args
+
+
+def yaml_ordered_load(stream, Loader=yaml.Loader, object_pairs_hook=OrderedDict):
+    """
+    Function to load YAML file using an OrderedDict
+    See: https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts
+    """
+    class OrderedLoader(Loader):
+        pass
+
+    def construct_mapping(loader, node):
+        loader.flatten_mapping(node)
+        return object_pairs_hook(loader.construct_pairs(node))
+
+    OrderedLoader.add_constructor(
+        yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
+        construct_mapping)
+
+    return yaml.load(stream, OrderedLoader)
diff --git a/distiller/learning_rate.py b/distiller/learning_rate.py
index 32325ee552dd0a712d2834c9b5830037cbad129c..657d0181c40d331a20d2d1ff23792ef74e904aed 100644
--- a/distiller/learning_rate.py
+++ b/distiller/learning_rate.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+from bisect import bisect_right
 from torch.optim.lr_scheduler import _LRScheduler
 
 
@@ -36,4 +37,22 @@ class PolynomialLR(_LRScheduler):
     def get_lr(self):
         # base_lr * (1 - iter/max_iter) ^ (power)
         return [base_lr * (1 - self.last_epoch / self.T_max) ** self.power
-                for base_lr in self.base_lrs]
\ No newline at end of file
+                for base_lr in self.base_lrs]
+
+
+class MultiStepMultiGammaLR(_LRScheduler):
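+    """Similar to torch.optim.lr_scheduler.MultiStepLR, but with a separate gamma per milestone.
+
+    The gammas are applied cumulatively: after the i-th milestone the learning rate becomes
+    base_lr * gammas[0] * ... * gammas[i].
+    """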
+    def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
+        if not list(milestones) == sorted(milestones):
+            raise ValueError('Milestones should be a list of'
+                             ' increasing integers. Got {}'.format(milestones))
+
+        self.milestones = milestones
+        self.multiplicative_gammas = [1]
+        for idx, gamma in enumerate(gammas):
+            self.multiplicative_gammas.append(gamma * self.multiplicative_gammas[idx])
+
+        super(MultiStepMultiGammaLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        idx = bisect_right(self.milestones, self.last_epoch)
+        return [base_lr * self.multiplicative_gammas[idx] for base_lr in self.base_lrs]
diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index b9edddcc28ebf3e778c63d883c754c8e45536fa9..908f3d667f8388d574ce1a4be52fcd8849e09773 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -31,12 +31,12 @@ import torch.optim
 import distiller
 msglogger = logging.getLogger()
 
-__all__ = ['model_summary', 'optimizer_summary',                        \
-           'weights_sparsity_summary', 'weights_sparsity_tbl_summary',  \
+__all__ = ['model_summary',
+           'weights_sparsity_summary', 'weights_sparsity_tbl_summary',
            'model_performance_summary', 'model_performance_tbl_summary']
 
 from .data_loggers import PythonLogger, CsvLogger
-def model_summary(model, optimizer, what, dataset=None):
+def model_summary(model, what, dataset=None):
     if what == 'sparsity':
         pylogger = PythonLogger(msglogger)
         csvlogger = CsvLogger('weights.csv')
@@ -55,8 +55,6 @@ def model_summary(model, optimizer, what, dataset=None):
         print(t)
         print("Total MACs: " + "{:,}".format(total_macs))
 
-    elif what == 'optimizer':
-        optimizer_summary(optimizer)
     elif what == 'model':
         # print the simple form of the model
         print(model)
@@ -72,21 +70,6 @@ def model_summary(model, optimizer, what, dataset=None):
                 nodes.append([name, module.__class__.__name__])
         print(tabulate(nodes, headers=['Name', 'Type']))
 
-def optimizer_summary(optimizer):
-    assert isinstance(optimizer, torch.optim.SGD)
-    lr = optimizer.param_groups[0]['lr']
-    weight_decay = optimizer.param_groups[0]['weight_decay']
-    momentum = optimizer.param_groups[0]['momentum']
-    dampening = optimizer.param_groups[0]['dampening']
-    nesterov = optimizer.param_groups[0]['nesterov']
-
-    msglogger.info('Optimizer:\n'
-                   '\tmomentum={}'
-                   '\tL2={}'
-                   '\tLR={}'
-                   '\tdampening={}'
-                   '\tnesterov={}'.format(momentum, weight_decay, lr, dampening, nesterov))
-
 
 def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4]):
 
diff --git a/distiller/policy.py b/distiller/policy.py
index 595f35559076ac9a66ff8568f560c8c1908b42e3..ca3f1c1a8de2b9fa6da9b84b75c786c85ff5938a 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -20,10 +20,12 @@
 - RegularizationPolicy: regulization scheduling
 - LRPolicy: learning-rate decay scheduling
 """
+import torch
+
 import logging
 msglogger = logging.getLogger()
 
-__all__ = ['PruningPolicy', 'RegularizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy']
+__all__ = ['PruningPolicy', 'RegularizationPolicy', 'QuantizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy']
 
 class ScheduledTrainingPolicy(object):
     """ Base class for all scheduled training policies.
@@ -130,3 +132,16 @@ class LRPolicy(ScheduledTrainingPolicy):
 
     def on_epoch_begin(self, model, zeros_mask_dict, meta):
         self.lr_scheduler.step()
+
+
+class QuantizationPolicy(ScheduledTrainingPolicy):
+    def __init__(self, quantizer):
+        super(QuantizationPolicy, self).__init__()
+        self.quantizer = quantizer
+        self.quantizer.prepare_model()
+        self.quantizer.quantize_params()
+
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+        # After the parameters have been updated, quantize them again.
+        # (Doing this here ensures the model parameters are quantized at training completion and at validation time.)
+        self.quantizer.quantize_params()
diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py
index cdb45b7c6bd843eb0290b0cd02008b89df45879b..e50edbfe02fef5ca6eaea7227dfe21fa20c175e9 100644
--- a/distiller/quantization/__init__.py
+++ b/distiller/quantization/__init__.py
@@ -16,3 +16,8 @@
 
 from .quantizer import Quantizer
 from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, SymmetricLinearQuantizer
+from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer
+
+del quantizer
+del range_linear
+del clipped_linear
diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..7711a05cbe186dabc59bc8eabc5116df928e2d98
--- /dev/null
+++ b/distiller/quantization/clipped_linear.py
@@ -0,0 +1,109 @@
+import torch.nn as nn
+
+from .quantizer import Quantizer
+from .q_utils import *
+import logging
+msglogger = logging.getLogger()
+
+###
+# Clipping-based linear quantization (e.g. DoReFa, WRPN)
+###
+
+
+class LinearQuantizeSTE(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, scale_factor, dequantize, inplace):
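+        # Quantize the input; if requested, dequantize it back so the output is a "fake-quantized"
+        # float tensor. The backward pass below passes gradients straight through (STE).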
+        if inplace:
+            ctx.mark_dirty(input)
+        output = linear_quantize(input, scale_factor, inplace)
+        if dequantize:
+            output = linear_dequantize(output, scale_factor, inplace)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Straight-through estimator
+        return grad_output, None, None, None
+
+
+class ClippedLinearQuantization(nn.Module):
+    def __init__(self, num_bits, clip_val, dequantize=True, inplace=False):
+        super(ClippedLinearQuantization, self).__init__()
+        self.num_bits = num_bits
+        self.clip_val = clip_val
+        self.scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, clip_val)
+        self.dequantize = dequantize
+        self.inplace = inplace
+
+    def forward(self, input):
+        input = clamp(input, 0, self.clip_val, self.inplace)
+        input = LinearQuantizeSTE.apply(input, self.scale_factor, self.dequantize, self.inplace)
+        return input
+
+    def __repr__(self):
+        inplace_str = ', inplace' if self.inplace else ''
+        return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val,
+                                                           inplace_str)
+
+
+class WRPNQuantizer(Quantizer):
+    """
+    Quantizer using the WRPN quantization scheme, as defined in:
+    Mishra et al., WRPN: Wide Reduced-Precision Networks (https://arxiv.org/abs/1709.01134)
+
+    Notes:
+        1. This class does not take care of layer widening as described in the paper
+        2. The paper defines special handling for 1-bit weights which isn't supported here yet
+    """
+    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}):
+        super(WRPNQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
+                                            bits_overrides=bits_overrides, train_with_fp_copy=True)
+
+        def wrpn_quantize_param(param_fp, num_bits):
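+            # WRPN weight quantization: clamp the float copy to [-1, 1] and quantize it
+            # symmetrically, using the STE for the backward pass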
+            scale_factor = symmetric_linear_quantization_scale_factor(num_bits, 1)
+            out = param_fp.clamp(-1, 1)
+            out = LinearQuantizeSTE.apply(out, scale_factor, True, False)
+            return out
+
+        def relu_replace_fn(module, name, qbits_map):
+            bits_acts = qbits_map[name].acts
+            if bits_acts is None:
+                return module
+            return ClippedLinearQuantization(bits_acts, 1, dequantize=True, inplace=module.inplace)
+
+        self.param_quantization_fn = wrpn_quantize_param
+
+        self.replacement_factory[nn.ReLU] = relu_replace_fn
+
+
+class DorefaQuantizer(Quantizer):
+    """
+    Quantizer using the DoReFa scheme, as defined in:
+    Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
+    (https://arxiv.org/abs/1606.06160)
+
+    Notes:
+        1. Gradients quantization not supported yet
+        2. The paper defines special handling for 1-bit weights which isn't supported here yet
+    """
+    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}):
+        super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
+                                              bits_overrides=bits_overrides, train_with_fp_copy=True)
+
+        def dorefa_quantize_param(param_fp, num_bits):
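+            # DoReFa-W: squash the weights into [0, 1] using the tanh-based normalization from the
+            # paper, quantize with the STE, then map the result back to [-1, 1]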
+            scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1)
+            out = param_fp.tanh()
+            out = out / (2 * out.abs().max()) + 0.5
+            out = LinearQuantizeSTE.apply(out, scale_factor, True, False)
+            out = 2 * out - 1
+            return out
+
+        def relu_replace_fn(module, name, qbits_map):
+            bits_acts = qbits_map[name].acts
+            if bits_acts is None:
+                return module
+            return ClippedLinearQuantization(bits_acts, 1, dequantize=True, inplace=module.inplace)
+
+        self.param_quantization_fn = dorefa_quantize_param
+
+        self.replacement_factory[nn.ReLU] = relu_replace_fn
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index c717d8baed40481c0bf40afe8b6a5a40ab5d5f65..6d1850bb19e6904cc88c64c15a36575be39b9588 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -16,16 +16,43 @@
 
 from collections import namedtuple
 import re
+import copy
 import logging
+import torch
+import torch.nn as nn
 
 msglogger = logging.getLogger()
 
 QBits = namedtuple('QBits', ['acts', 'wts'])
 
+FP_BKP_PREFIX = 'float_'
+
+
+def has_bias(module):
+    return hasattr(module, 'bias') and module.bias is not None
+
+
+def hack_float_backup_parameter(module, name):
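+    # Replace the Parameter 'name' with a buffer of the same shape, and re-register the original
+    # float data as a new Parameter prefixed with FP_BKP_PREFIX. The optimizer then updates the
+    # float copy, while the buffer holds the quantized values used in the forward pass.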
+    try:
+        data = dict(module.named_parameters())[name].data
+    except KeyError:
+        raise ValueError('Module has no Parameter named ' + name)
+    module.register_parameter(FP_BKP_PREFIX + name, nn.Parameter(data))
+    module.__delattr__(name)
+    module.register_buffer(name, torch.zeros_like(data))
+
+
+class _ParamToQuant(object):
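+    """Bookkeeping record for a single parameter that quantize_params() will quantize"""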
+    def __init__(self, module, fp_attr_name, q_attr_name, num_bits):
+        self.module = module
+        self.fp_attr_name = fp_attr_name
+        self.q_attr_name = q_attr_name
+        self.num_bits = num_bits
+
 
 class Quantizer(object):
     r"""
-    Base class for quantizers
+    Base class for quantizers.
 
     Args:
         model (torch.nn.Module): The model to be quantized
@@ -33,54 +60,135 @@ class Quantizer(object):
             Value of None means do not quantize.
         bits_overrides (dict): Dictionary mapping regular expressions of layer name patterns to dictionary with
             values for 'acts' and/or 'wts' to override the default values.
+        quantize_bias (bool): Flag indicating whether to quantize the bias (with the same number of bits as the weights)
+        train_with_fp_copy (bool): If true, will modify layers with weights to keep both a quantized and
+            floating-point copy, such that the following flow occurs in each training iteration:
+            1. q_weights = quantize(fp_weights)
+            2. Forward through network using q_weights
+            3. In back-prop:
+                3.1 Gradients calculated with respect to q_weights
+                3.2 We also back-prop through the 'quantize' operation from step 1
+            4. Update fp_weights with gradients calculated in step 3.2
     """
-    def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides={}):
+    def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides={}, quantize_bias=False,
+                 train_with_fp_copy=False):
         self.default_qbits = QBits(acts=bits_activations, wts=bits_weights)
+        self.quantize_bias = quantize_bias
 
         self.model = model
 
+        # Stash some quantizer data in the model so we can re-apply the quantizer on a resuming model
+        self.model.quantizer_metadata = {'type': type(self),
+                                         'params': {'bits_activations': bits_activations,
+                                                    'bits_weights': bits_weights,
+                                                    'bits_overrides': copy.deepcopy(bits_overrides)}}
+
         for k, v in bits_overrides.items():
             qbits = QBits(acts=v.get('acts', self.default_qbits.acts), wts=v.get('wts', self.default_qbits.wts))
             bits_overrides[k] = qbits
 
         # Prepare explicit mapping from each layer to QBits based on default + overrides
+        regex = None
         if bits_overrides:
             regex_str = ''
             keys_list = list(bits_overrides.keys())
             for pattern in keys_list:
                 regex_str += '(^{0}$)|'.format(pattern)
-            regex_str = regex_str[-1]   # Remove trailing '|'
+            regex_str = regex_str[:-1]   # Remove trailing '|'
             regex = re.compile(regex_str)
 
-            self.layer_qbits_map = {}
-            for layer_full_name, _ in model.named_modules():
-                m = regex.match(layer_full_name)
+        self.module_qbits_map = {}
+        for module_full_name, module in model.named_modules():
+            qbits = self.default_qbits
+            if regex:
+                # Need to account for scenario where model is parallelized with DataParallel, which wraps the original
+                # module with a wrapper module called 'module' :)
+                name_to_match = module_full_name.replace('module.', '', 1)
+                m = regex.match(name_to_match)
                 if m:
                     group_idx = 0
                     groups = m.groups()
                     while groups[group_idx] is None:
                         group_idx += 1
-                    self.layer_qbits_map[layer_full_name] = bits_overrides[keys_list[group_idx]]
-                else:
-                    self.layer_qbits_map[layer_full_name] = self.default_qbits
-        else:
-            self.layer_qbits_map = {layer_full_name: self.default_qbits for layer_full_name, _ in model.named_modules()}
+                    qbits = bits_overrides[keys_list[group_idx]]
+            self._add_qbits_entry(module_full_name, type(module), qbits)
 
+        # Mapping from module type to function generating a replacement module suited for quantization
+        # To be populated by child classes
         self.replacement_factory = {}
+        # Pointer to parameters quantization function, triggered during training process
+        # To be populated by child classes
+        self.param_quantization_fn = None
+
+        self.train_with_fp_copy = train_with_fp_copy
+        self.params_to_quantize = []
+
+    def _add_qbits_entry(self, module_name, module_type, qbits):
+        if module_type not in [nn.Conv2d, nn.Linear]:
+            # For now we support weights quantization only for Conv and FC layers (so, for example, we don't
+            # support quantization of batch norm scale parameters)
+            qbits = QBits(acts=qbits.acts, wts=None)
+        self.module_qbits_map[module_name] = qbits
 
     def prepare_model(self):
-        msglogger.info('Preparing model for quantization')
+        r"""
+        Iterates over the model, replaces modules with their quantized counterparts as defined by
+        self.replacement_factory, and registers the parameters that will be quantized during training
+        """
+        msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))
         self._pre_process_container(self.model)
 
+        for module_name, module in self.model.named_modules():
+            qbits = self.module_qbits_map[module_name]
+            if qbits.wts is None:
+                continue
+
+            curr_parameters = dict(module.named_parameters())
+            for param_name, param in curr_parameters.items():
+                if param_name.endswith('bias') and not self.quantize_bias:
+                    continue
+                fp_attr_name = param_name
+                if self.train_with_fp_copy:
+                    hack_float_backup_parameter(module, param_name)
+                    fp_attr_name = FP_BKP_PREFIX + param_name
+                self.params_to_quantize.append(_ParamToQuant(module, fp_attr_name, param_name, qbits.wts))
+
+                param_full_name = '.'.join([module_name, param_name])
+                msglogger.info(
+                    "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, qbits.wts))
+
+        msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
+
     def _pre_process_container(self, container, prefix=''):
         # Iterate through model, insert quantization functions as appropriate
         for name, module in container.named_children():
             full_name = prefix + name
             try:
-                new_module = self.replacement_factory[type(module)](module, full_name, self.layer_qbits_map)
+                new_module = self.replacement_factory[type(module)](module, full_name, self.module_qbits_map)
                 msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module))
                 container._modules[name] = new_module
+
+                # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
+                if len(module._modules) == 0 and len(new_module._modules) > 0:
+                    current_qbits = self.module_qbits_map[full_name]
+                    for sub_module_name, sub_module in new_module.named_modules():
+                        self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
+                    self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None)
             except KeyError:
+                pass
+
+            if len(module._modules) > 0:
                 # For container we call recursively
-                if len(module._modules) > 0:
-                    self._pre_process_container(module, full_name + '.')
+                self._pre_process_container(module, full_name + '.')
+
+    def quantize_params(self):
+        """
+        Quantize all parameters using self.param_quantization_fn (with the number of bits defined
+        for each parameter)
+        """
+        for ptq in self.params_to_quantize:
+            q_param = self.param_quantization_fn(ptq.module.__getattr__(ptq.fp_attr_name), ptq.num_bits)
+            if self.train_with_fp_copy:
+                ptq.module.__setattr__(ptq.q_attr_name, q_param)
+            else:
+                ptq.module.__getattr__(ptq.q_attr_name).data = q_param.data
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 31645cafeac40a2c0473dfa5b9b07a751d6976a0..bca1cb9926773b8511962fcc3b5ef0a2caf0a728 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -161,7 +161,8 @@ class SymmetricLinearQuantizer(Quantizer):
     """
     def __init__(self, model, bits_activations=8, bits_parameters=8):
         super(SymmetricLinearQuantizer, self).__init__(model, bits_activations=bits_activations,
-                                                       bits_weights=bits_parameters)
+                                                       bits_weights=bits_parameters,
+                                                       train_with_fp_copy=False)
 
         def replace_fn(module, name, qbits_map):
             return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts)
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 2f85151c2981e1f7601fad643550b3f6d9f295ff..e4c1d140c51a91a1de604936bbdf5dbe60d0fd92 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -21,6 +21,7 @@ This implements the scheduling of the compression policies.
 from functools import partial
 import logging
 import torch
+from .quantization.quantizer import FP_BKP_PREFIX
 
 msglogger = logging.getLogger()
 
@@ -127,7 +128,6 @@ class CompressionScheduler(object):
                 policy.on_minibatch_end(self.model, epoch, minibatch_id, minibatches_per_epoch,
                                         self.zeros_mask_dict)
 
- 
     def on_epoch_end(self, epoch):
         if epoch in self.policies:
             for policy in self.policies[epoch]:
@@ -135,10 +135,20 @@ class CompressionScheduler(object):
                 meta['current_epoch'] = epoch
                 policy.on_epoch_end(self.model, self.zeros_mask_dict, meta)
 
-
     def apply_mask(self):
         for name, param in self.model.named_parameters():
-            self.zeros_mask_dict[name].apply_mask(param)
+            try:
+                self.zeros_mask_dict[name].apply_mask(param)
+            except KeyError:
+                # Quantizers for training rename some of the model parameters by adding a prefix
+                # If this is the source of the error, strip the prefix and apply the mask to the renamed parameter
+                name_parts = name.split('.')
+                if name_parts[-1].startswith(FP_BKP_PREFIX):
+                    name_parts[-1] = name_parts[-1].replace(FP_BKP_PREFIX, '', 1)
+                    name = '.'.join(name_parts)
+                    self.zeros_mask_dict[name].apply_mask(param)
+                else:
+                    raise
 
 
     def state_dict(self):
@@ -149,10 +159,10 @@ class CompressionScheduler(object):
         masks = {}
         for name, masker in self.zeros_mask_dict.items():
             masks[name] = masker.mask
-        state = { 'masks_dict' : masks }
+        state = {'masks_dict': masks,
+                 'parallel_model': isinstance(self.model, torch.nn.DataParallel)}
         return state
 
-
     def load_state_dict(self, state):
         """Loads the scheduler state.
 
@@ -173,6 +183,17 @@ class CompressionScheduler(object):
                 print("\t\t" + k)
             exit(1)
 
+        curr_model_parallel = isinstance(self.model, torch.nn.DataParallel)
+        # Fallback to 'True' for old checkpoints that don't have this attribute, since parallel=True is the
+        # default for create_model
+        loaded_model_parallel = state.get('parallel_model', True)
         for name, mask in self.zeros_mask_dict.items():
+            # DataParallel modules wrap the actual module with a module named "module"...
+            if loaded_model_parallel and not curr_model_parallel:
+                load_name = 'module.' + name
+            elif curr_model_parallel and not loaded_model_parallel:
+                load_name = name.replace('module.', '', 1)
+            else:
+                load_name = name
             masker = self.zeros_mask_dict[name]
-            masker.mask = loaded_masks[name]
+            masker.mask = loaded_masks[load_name]
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 338ec78438c0bdbbc757dff475c34e744445983c..f33e171980a20646649a5b0dc3dd82b9b957f99c 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -80,6 +80,15 @@ import distiller.quantization as quantization
 from models import ALL_MODEL_NAMES, create_model
 
 msglogger = None
+
+
+def float_range(val_str):
+    val = float(val_str)
+    if val < 0 or val >= 1:
+        raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
+    return val
+
+
 parser = argparse.ArgumentParser(description='Distiller image classification model compression')
 parser.add_argument('data', metavar='DIR', help='path to dataset')
 parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
@@ -111,7 +120,7 @@ parser.add_argument('--act-stats', dest='activation_stats', action='store_true',
                     help='collect activation statistics (WARNING: this slows down training)')
 parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
                     help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)')
-SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png', 'png_w_params']
+SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params']
 parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
                     help='print a summary of the model, and exit - options: ' +
                     ' | '.join(SUMMARY_CHOICES))
@@ -128,6 +137,9 @@ parser.add_argument('--quantize', action='store_true',
 parser.add_argument('--gpus', metavar='DEV_ID', default=None,
                     help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
 parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
+parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
+parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
+                    help='Portion of training dataset to set aside for validation')
 
 
 def check_pytorch_version():
@@ -142,11 +154,14 @@ def check_pytorch_version():
               "  3. Activate the new environment")
         exit(1)
 
+
 def main():
     global msglogger
     check_pytorch_version()
     args = parser.parse_args()
-    msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name)
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)
 
     # Log various details about the execution environment.  It is sometimes useful
     # to refer to past experiment executions and this information may be useful.
@@ -194,7 +209,7 @@ def main():
 
     # Create the model
     png_summary = args.summary is not None and args.summary.startswith('png')
-    is_parallel = not png_summary   # For PNG summary, parallel graphs are illegible
+    is_parallel = not png_summary and args.summary != 'compute'  # For PNG summary, parallel graphs are illegible; the compute summary also uses a non-parallel model
     model = create_model(args.pretrained, args.dataset, args.arch, parallel=is_parallel, device_ids=args.gpus)
 
     compression_scheduler = None
@@ -208,18 +223,17 @@ def main():
         model, compression_scheduler, start_epoch = apputils.load_checkpoint(
             model, chkpt_file=args.resume)
 
-        if 'resnet' in args.arch and 'cifar' in args.arch:
+        if 'resnet' in args.arch and 'preact' not in args.arch and 'cifar' in args.arch:
             distiller.resnet_cifar_remove_layers(model)
             #model = distiller.resnet_cifar_remove_channels(model, compression_scheduler.zeros_mask_dict)
 
     # Define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
-    msglogger.info("Optimizer (%s): momentum=%s decay=%s", type(optimizer),
-                   args.momentum, args.weight_decay)
-
+    msglogger.info('Optimizer Type: %s', type(optimizer))
+    msglogger.info('Optimizer Args: %s', optimizer.defaults)
 
     # This sample application can be invoked to produce various summary reports.
     if args.summary:
@@ -227,7 +241,7 @@ def main():
         if which_summary.startswith('png'):
             apputils.draw_img_classifier_to_file(model, 'model.png', args.dataset, which_summary == 'png_w_params')
         else:
-            distiller.model_summary(model, optimizer, which_summary, args.dataset)
+            distiller.model_summary(model, which_summary, args.dataset)
         exit()
 
     # Load the datasets: the dataset to load is inferred from the model name passed
@@ -235,7 +249,7 @@ def main():
     # substring "_cifar", then cifar10 is used.
     train_loader, val_loader, test_loader, _ = apputils.load_data(
         args.dataset, os.path.expanduser(args.data), args.batch_size,
-        args.workers, args.deterministic)
+        args.workers, args.validation_size, args.deterministic)
     msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                    len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler))
 
@@ -276,17 +290,15 @@ def main():
         top1, _, _ = test(test_loader, model, criterion, [pylogger], args.print_freq)
         if args.quantize:
             checkpoint_name = 'quantized'
-            apputils.save_checkpoint(0, args.arch, model, optimizer, best_top1=top1,
-                                     name='_'.split(args.name, checkpoint_name) if args.name else checkpoint_name)
+            apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1,
+                                     name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name,
+                                     dir=msglogger.logdir)
         exit()
 
     if args.compress:
-        # The main use-case for this sample application is CNN compression.  Compression
+        # The main use-case for this sample application is CNN compression. Compression
         # requires a compression schedule configuration file in YAML.
-        source = args.compress
-        msglogger.info("Compression schedule (source=%s)", source)
-        compression_scheduler = distiller.CompressionScheduler(model)
-        distiller.config.fileConfig(model, optimizer, compression_scheduler, args.compress, msglogger)
+        compression_scheduler = distiller.file_config(model, optimizer, args.compress)
 
     for epoch in range(start_epoch, start_epoch + args.epochs):
         # This is the main training loop.
@@ -317,7 +329,8 @@ def main():
         # remember best top1 and save checkpoint
         is_best = top1 > best_top1
         best_top1 = max(top1, best_top1)
-        apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, best_top1, is_best, args.name)
+        apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, best_top1, is_best,
+                                 args.name, msglogger.logdir)
 
     # Finally run results on the test set
     test(test_loader, model, criterion, [pylogger], args.print_freq)
@@ -353,7 +366,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
         input_var = torch.autograd.Variable(inputs)
         target_var = torch.autograd.Variable(target)
 
-        # Execute the forard phase, compute the output and measure loss
+        # Execute the forward phase, compute the output and measure loss
         if compression_scheduler:
             compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch)
         output = model(input_var)
diff --git a/examples/quantization/alexnet_bn_base_fp32.yaml b/examples/quantization/alexnet_bn_base_fp32.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..89b57238ac62dcbee7184e5f2c8d94d3d0893a80
--- /dev/null
+++ b/examples/quantization/alexnet_bn_base_fp32.yaml
@@ -0,0 +1,12 @@
+lr_schedulers:
+  training_lr:
+    class: MultiStepLR
+    milestones: [60, 75]
+    gamma: 0.2
+
+policies:
+    - lr_scheduler:
+        instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
diff --git a/examples/quantization/alexnet_bn_dorefa.yaml b/examples/quantization/alexnet_bn_dorefa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4ada4579a90c147466fc1d374f648cd63b667aa6
--- /dev/null
+++ b/examples/quantization/alexnet_bn_dorefa.yaml
@@ -0,0 +1,37 @@
+quantizers:
+  dorefa_quantizer:
+    class: DorefaQuantizer
+    bits_activations: 8
+    bits_weights: 3
+    bits_overrides:
+      features.0:
+        wts: null
+        acts: null
+      features.1:
+        wts: null
+        acts: null
+      classifier.5:
+        wts: null
+        acts: null
+      classifier.6:
+        wts: null
+        acts: null
+
+lr_schedulers:
+  training_lr:
+    class: MultiStepLR
+    milestones: [60, 75]
+    gamma: 0.2
+
+policies:
+    - quantizer:
+        instance_name: dorefa_quantizer
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
+
+    - lr_scheduler:
+        instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
diff --git a/examples/quantization/preact_resnet18_imagenet_base_fp32.yaml b/examples/quantization/preact_resnet18_imagenet_base_fp32.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2b71df90d4abb743cedf2d55910bc77b0d7bff9b
--- /dev/null
+++ b/examples/quantization/preact_resnet18_imagenet_base_fp32.yaml
@@ -0,0 +1,12 @@
+lr_schedulers:
+  training_lr:
+    class: MultiStepLR
+    milestones: [30, 60, 90, 100]
+    gamma: 0.1
+
+policies:
+    - lr_scheduler:
+        instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
diff --git a/examples/quantization/preact_resnet18_imagenet_dorefa.yaml b/examples/quantization/preact_resnet18_imagenet_dorefa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..baf93d9da3f747596fae15d33673928043561cb1
--- /dev/null
+++ b/examples/quantization/preact_resnet18_imagenet_dorefa.yaml
@@ -0,0 +1,37 @@
+quantizers:
+  dorefa_quantizer:
+    class: DorefaQuantizer
+    bits_activations: 8
+    bits_weights: 3
+    bits_overrides:
+      conv1:
+        wts: null
+        acts: null
+      relu1:
+        wts: null
+        acts: null
+      final_relu:
+        wts: null
+        acts: null
+      fc:
+        wts: null
+        acts: null
+
+lr_schedulers:
+  training_lr:
+    class: MultiStepLR
+    milestones: [30, 60, 90, 100]
+    gamma: 0.1
+
+policies:
+    - quantizer:
+        instance_name: dorefa_quantizer
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
+
+    - lr_scheduler:
+        instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
diff --git a/examples/quantization/preact_resnet20_cifar_base_fp32.yaml b/examples/quantization/preact_resnet20_cifar_base_fp32.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..792d971ae469b83d2de306c6c510ec9404d8c6eb
--- /dev/null
+++ b/examples/quantization/preact_resnet20_cifar_base_fp32.yaml
@@ -0,0 +1,12 @@
+lr_schedulers:
+  training_lr:
+    class: MultiStepMultiGammaLR
+    milestones: [80, 120, 160]
+    gammas: [0.1, 0.1, 0.2]
+
+policies:
+    - lr_scheduler:
+        instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
diff --git a/examples/quantization/preact_resnet20_cifar_dorefa.yaml b/examples/quantization/preact_resnet20_cifar_dorefa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..94851236f4c9a4c11d335a36178a3ad1a0e02b00
--- /dev/null
+++ b/examples/quantization/preact_resnet20_cifar_dorefa.yaml
@@ -0,0 +1,37 @@
+quantizers:
+  dorefa_quantizer:
+    class: DorefaQuantizer
+    bits_activations: 8
+    bits_weights: 3
+    bits_overrides:
+      conv1:
+        wts: null
+        acts: null
+      layer1.0.pre_relu:
+        wts: null
+        acts: null
+      final_relu:
+        wts: null
+        acts: null
+      fc:
+        wts: null
+        acts: null
+
+lr_schedulers:
+  training_lr:
+    class: MultiStepMultiGammaLR
+    milestones: [80, 120, 160]
+    gammas: [0.1, 0.1, 0.2]
+
+policies:
+    - quantizer:
+        instance_name: dorefa_quantizer
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
+
+    - lr_scheduler:
+        instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 161
+      frequency: 1
diff --git a/examples/word_language_model/main.py b/examples/word_language_model/main.py
index 328eb450feb67a1d38b2219c995cf2b1130c35cf..767b80c57d532fdaaf78c997dfbf463a1260be57 100755
--- a/examples/word_language_model/main.py
+++ b/examples/word_language_model/main.py
@@ -314,8 +314,7 @@ compression_scheduler = None
 if args.compress:
     # Create a CompressionScheduler and configure it from a YAML schedule file
     source = args.compress
-    compression_scheduler = distiller.CompressionScheduler(model)
-    distiller.config.fileConfig(model, None, compression_scheduler, args.compress, msglogger)
+    compression_scheduler = distiller.config.file_config(model, None, args.compress)
 
 optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                  momentum=args.momentum,
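
For scripts still using the old two-step API, the replacement above is the whole change: file_config builds the CompressionScheduler from the YAML schedule and returns it. A minimal sketch of the same pattern in isolation (the model constructor and YAML path are illustrative):

    import distiller
    from models.cifar10 import preact_resnet20_cifar  # any supported model

    model = preact_resnet20_cifar()
    # None is passed for the second argument, mirroring the call in main.py above.
    compression_scheduler = distiller.config.file_config(
        model, None, 'examples/quantization/preact_resnet20_cifar_dorefa.yaml')
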
diff --git a/models/cifar10/__init__.py b/models/cifar10/__init__.py
index ffc6ed3897ea336eeb26b4b8629ca79bf71a19eb..3b72572744f7dc61529c05213624c9208f466678 100755
--- a/models/cifar10/__init__.py
+++ b/models/cifar10/__init__.py
@@ -18,3 +18,4 @@
 
 from .simplenet_cifar import *
 from .resnet_cifar import *
+from .preresnet_cifar import *
diff --git a/models/cifar10/preresnet_cifar.py b/models/cifar10/preresnet_cifar.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8aa50c3f53ac3191e317101246385aa5cfff313
--- /dev/null
+++ b/models/cifar10/preresnet_cifar.py
@@ -0,0 +1,208 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Pre-Activation ResNet for CIFAR10
+
+Pre-Activation ResNet for CIFAR10, based on "Identity Mappings in Deep Residual Networks".
+This is based on TorchVision's implementation of ResNet for ImageNet, with the appropriate
+changes for pre-activation and the 10-class CIFAR-10 dataset.
+This ResNet also has layer gates, so that individual layers can be dynamically removed.
+
+@article{
+  He2016,
+  author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
+  title = {Identity Mappings in Deep Residual Networks},
+  journal = {arXiv preprint arXiv:1603.05027},
+  year = {2016}
+}
+"""
+import torch.nn as nn
+import math
+
+__all__ = ['preact_resnet20_cifar', 'preact_resnet32_cifar', 'preact_resnet44_cifar', 'preact_resnet56_cifar',
+           'preact_resnet110_cifar', 'preact_resnet20_cifar_conv_ds', 'preact_resnet32_cifar_conv_ds',
+           'preact_resnet44_cifar_conv_ds', 'preact_resnet56_cifar_conv_ds', 'preact_resnet110_cifar_conv_ds']
+
+NUM_CLASSES = 10
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class PreactBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, block_gates, inplanes, planes, stride=1, downsample=None, preact_downsample=True):
+        super(PreactBasicBlock, self).__init__()
+        self.block_gates = block_gates
+        self.pre_bn = nn.BatchNorm2d(inplanes)
+        self.pre_relu = nn.ReLU(inplace=False)  # To enable layer removal, inplace must be False
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=False)
+        self.conv2 = conv3x3(planes, planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.preact_downsample = preact_downsample
+
+    def forward(self, x):
+        need_preact = self.block_gates[0] or self.block_gates[1] or (self.downsample is not None and self.preact_downsample)
+        if need_preact:
+            preact = self.pre_bn(x)
+            preact = self.pre_relu(preact)
+            out = preact
+        else:
+            preact = out = x
+
+        if self.block_gates[0]:
+            out = self.conv1(out)
+            out = self.bn(out)
+            out = self.relu(out)
+
+        if self.block_gates[1]:
+            out = self.conv2(out)
+
+        if self.downsample is not None:
+            if self.preact_downsample:
+                residual = self.downsample(preact)
+            else:
+                residual = self.downsample(x)
+        else:
+            residual = x
+
+        out += residual
+
+        return out
+
+
+class PreactResNetCifar(nn.Module):
+    def __init__(self, block, layers, num_classes=NUM_CLASSES, conv_downsample=False):
+        self.nlayers = 0
+        # Each layer manages its own gates
+        self.layer_gates = []
+        for layer in range(3):
+            # For each of the 3 residual stages, create block gates: each block has two gated conv layers
+            self.layer_gates.append([])
+            for blk in range(layers[layer]):
+                self.layer_gates[layer].append([True, True])
+
+        self.inplanes = 16  # 64
+        super(PreactResNetCifar, self).__init__()
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.layer1 = self._make_layer(self.layer_gates[0], block, 16, layers[0],
+                                       conv_downsample=conv_downsample)
+        self.layer2 = self._make_layer(self.layer_gates[1], block, 32, layers[1], stride=2,
+                                       conv_downsample=conv_downsample)
+        self.layer3 = self._make_layer(self.layer_gates[2], block, 64, layers[2], stride=2,
+                                       conv_downsample=conv_downsample)
+        self.final_bn = nn.BatchNorm2d(64 * block.expansion)
+        self.final_relu = nn.ReLU(inplace=True)
+        self.avgpool = nn.AvgPool2d(8, stride=1)
+        self.fc = nn.Linear(64 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, layer_gates, block, planes, blocks, stride=1, conv_downsample=False):
+        downsample = None
+        outplanes = planes * block.expansion
+        if stride != 1 or self.inplanes != outplanes:
+            if conv_downsample:
+                downsample = nn.Conv2d(self.inplanes, outplanes,
+                                       kernel_size=1, stride=stride, bias=False)
+            else:
+                # Identity downsample uses strided average pooling + padding instead of convolution
+                pad_amount = int(self.inplanes / 2)
+                downsample = nn.Sequential(
+                    nn.AvgPool2d(2),
+                    nn.ConstantPad3d((0, 0, 0, 0, pad_amount, pad_amount), 0)
+                )
+
+        layers = []
+        layers.append(block(layer_gates[0], self.inplanes, planes, stride, downsample, conv_downsample))
+        self.inplanes = outplanes
+        for i in range(1, blocks):
+            layers.append(block(layer_gates[i], self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = self.final_bn(x)
+        x = self.final_relu(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def preact_resnet20_cifar(**kwargs):
+    model = PreactResNetCifar(PreactBasicBlock, [3, 3, 3], **kwargs)
+    return model
+
+
+def preact_resnet32_cifar(**kwargs):
+    model = PreactResNetCifar(PreactBasicBlock, [5, 5, 5], **kwargs)
+    return model
+
+
+def preact_resnet44_cifar(**kwargs):
+    model = PreactResNetCifar(PreactBasicBlock, [7, 7, 7], **kwargs)
+    return model
+
+
+def preact_resnet56_cifar(**kwargs):
+    model = PreactResNetCifar(PreactBasicBlock, [9, 9, 9], **kwargs)
+    return model
+
+
+def preact_resnet110_cifar(**kwargs):
+    model = PreactResNetCifar(PreactBasicBlock, [18, 18, 18], **kwargs)
+    return model
+
+
+def preact_resnet20_cifar_conv_ds(**kwargs):
+    return preact_resnet20_cifar(conv_downsample=True, **kwargs)
+
+
+def preact_resnet32_cifar_conv_ds(**kwargs):
+    return preact_resnet32_cifar(conv_downsample=True, **kwargs)
+
+
+def preact_resnet44_cifar_conv_ds(**kwargs):
+    return preact_resnet44_cifar(conv_downsample=True, **kwargs)
+
+
+def preact_resnet56_cifar_conv_ds(**kwargs):
+    return preact_resnet56_cifar(conv_downsample=True, **kwargs)
+
+
+def preact_resnet110_cifar_conv_ds(**kwargs):
+    return preact_resnet110_cifar(conv_downsample=True, **kwargs)
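
A short usage sketch for the new CIFAR pre-activation ResNets, including the per-block layer gates (nothing below is part of the file itself; shapes follow the definitions above):

    import torch
    from models.cifar10.preresnet_cifar import preact_resnet20_cifar

    model = preact_resnet20_cifar().eval()
    out = model(torch.randn(1, 3, 32, 32))   # CIFAR-10 input -> 10 logits
    assert out.shape == (1, 10)

    # Each residual stage exposes per-block gate pairs; flipping one to False
    # skips that conv in forward() (the hook used for dynamic layer removal).
    model.layer_gates[0][0][1] = False       # skip conv2 of layer1's first block
    out = model(torch.randn(1, 3, 32, 32))   # still (1, 10)
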
diff --git a/models/cifar10/resnet_cifar.py b/models/cifar10/resnet_cifar.py
index 44d2ba2cbb16a7ae51c4a3757051556a013e48c0..e9ce4e55ab026ba545223d42bce061328f4105d3 100755
--- a/models/cifar10/resnet_cifar.py
+++ b/models/cifar10/resnet_cifar.py
@@ -53,16 +53,13 @@ class BasicBlock(nn.Module):
 
     def __init__(self, block_gates, inplanes, planes, stride=1, downsample=None):
         super(BasicBlock, self).__init__()
-        #self.layer_id = layer_id
-        #self.block_id = block_id
         self.block_gates = block_gates
         self.conv1 = conv3x3(inplanes, planes, stride)
         self.bn1 = nn.BatchNorm2d(planes)
-        # This change is required for layer removal
-        #self.relu = nn.ReLU(inplace=True)
-        self.relu = nn.ReLU(inplace=False)
+        self.relu1 = nn.ReLU(inplace=False)  # To enable layer removal, inplace must be False
         self.conv2 = conv3x3(planes, planes)
         self.bn2 = nn.BatchNorm2d(planes)
+        self.relu2 = nn.ReLU(inplace=False)
         self.downsample = downsample
         self.stride = stride
 
@@ -72,7 +69,7 @@ class BasicBlock(nn.Module):
         if self.block_gates[0]:
             out = self.conv1(x)
             out = self.bn1(out)
-            out = self.relu(out)
+            out = self.relu1(out)
 
         if self.block_gates[1]:
             out = self.conv2(out)
@@ -82,7 +79,7 @@ class BasicBlock(nn.Module):
             residual = self.downsample(x)
 
         out += residual
-        out = self.relu(out)
+        out = self.relu2(out)
 
         return out
 
@@ -95,27 +92,19 @@ class ResNetCifar(nn.Module):
         self.layer_gates = []
         for layer in range(3):
             # For each of the 3 layers, create block gates: each block has two layers
-            self.layer_gates.append([]) # [True, True] * layers[layer])
+            self.layer_gates.append([])  # [True, True] * layers[layer])
             for blk in range(layers[layer]):
                 self.layer_gates[layer].append([True, True])
 
-        self.inplanes = 16 # 64
+        self.inplanes = 16  # 64
         super(ResNetCifar, self).__init__()
-        #self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
-        #self.bn1 = nn.BatchNorm2d(64)
         self.bn1 = nn.BatchNorm2d(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
-        #self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(self.layer_gates[0], block, 16, layers[0])
         self.layer2 = self._make_layer(self.layer_gates[1], block, 32, layers[1], stride=2)
         self.layer3 = self._make_layer(self.layer_gates[2], block, 64, layers[2], stride=2)
-        # self.layer1 = self._make_layer(block, 64, layers[0])
-        # self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
-        # self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
-        #self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
         self.avgpool = nn.AvgPool2d(8, stride=1)
-        #self.fc = nn.Linear(512 * block.expansion, num_classes)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
         for m in self.modules():
@@ -147,12 +136,10 @@ class ResNetCifar(nn.Module):
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
-        #x = self.maxpool(x)
 
         x = self.layer1(x)
         x = self.layer2(x)
         x = self.layer3(x)
-        #x = self.layer4(x)
 
         x = self.avgpool(x)
         x = x.view(x.size(0), -1)
diff --git a/models/imagenet/__init__.py b/models/imagenet/__init__.py
index ff4d9e2840fa9103eb5c8e1915e251ab32e48f62..300ebd50ff354f6555d51a224c7f6f4c91491b36 100755
--- a/models/imagenet/__init__.py
+++ b/models/imagenet/__init__.py
@@ -17,3 +17,5 @@
 """This package contains ImageNet image classification models not found in torchvision"""
 
 from .mobilenet import *
+from .preresnet_imagenet import *
+from .alexnet_batchnorm import *
diff --git a/models/imagenet/alexnet_batchnorm.py b/models/imagenet/alexnet_batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..edd444529580421aa9a1c7ed7e5548db9d9b68fa
--- /dev/null
+++ b/models/imagenet/alexnet_batchnorm.py
@@ -0,0 +1,92 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+AlexNet model with batch-norm layers.
+Model configuration based on the AlexNet DoReFa example in TensorPack:
+https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py
+
+Code based on the AlexNet PyTorch sample, with the required changes.
+"""
+
+import math
+import torch.nn as nn
+
+__all__ = ['AlexNetBN', 'alexnet_bn']
+
+
+class AlexNetBN(nn.Module):
+
+    def __init__(self, num_classes=1000):
+        super(AlexNetBN, self).__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 96, kernel_size=12, stride=4),                           # conv0 (224x224x3) -> (54x54x96)
+            nn.ReLU(inplace=True),
+            nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2, bias=False),   # conv1 (54x54x96)  -> (54x54x256)
+            nn.BatchNorm2d(256, eps=1e-4, momentum=0.9),                          # bn1   (54x54x256)
+            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),                # pool1 (54x54x256) -> (27x27x256)
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(256, 384, kernel_size=3, padding=1, bias=False),            # conv2 (27x27x256) -> (27x27x384)
+            nn.BatchNorm2d(384, eps=1e-4, momentum=0.9),                          # bn2   (27x27x384)
+            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),                     # pool2 (27x27x384) -> (14x14x384)
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2, bias=False),  # conv3 (14x14x384) -> (14x14x384)
+            nn.BatchNorm2d(384, eps=1e-4, momentum=0.9),                          # bn3   (14x14x384)
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2, bias=False),  # conv4 (14x14x384) -> (14x14x256)
+            nn.BatchNorm2d(256, eps=1e-4, momentum=0.9),                          # bn4   (14x14x256)
+            nn.MaxPool2d(kernel_size=3, stride=2),                                # pool4 (14x14x256) -> (6x6x256)
+            nn.ReLU(inplace=True),
+        )
+        self.classifier = nn.Sequential(
+            nn.Linear(256 * 6 * 6, 4096, bias=False),       # fc0
+            nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9),   # bnfc0
+            nn.ReLU(inplace=True),
+            nn.Linear(4096, 4096, bias=False),              # fc1
+            nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9),   # bnfc1
+            nn.ReLU(inplace=True),
+            nn.Linear(4096, num_classes),                   # fct
+        )
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                fan_in, k_size = (m.in_channels, m.kernel_size[0] * m.kernel_size[1]) if isinstance(m, nn.Conv2d) \
+                    else (m.in_features, 1)
+                n = k_size * fan_in
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+                if hasattr(m, 'bias') and m.bias is not None:
+                    m.bias.data.fill_(0)
+            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), 256 * 6 * 6)
+        x = self.classifier(x)
+        return x
+
+
+def alexnet_bn(**kwargs):
+    r"""AlexNet model with batch-norm layers.
+    Model configuration based on the AlexNet DoReFa example in `TensorPack
+    <https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py>`_
+    """
+    model = AlexNetBN(**kwargs)
+    return model
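
A quick shape check for the new AlexNetBN model; the 224x224 input size is implied by the per-layer shape comments above (this snippet is illustrative only):

    import torch
    from models.imagenet.alexnet_batchnorm import alexnet_bn

    model = alexnet_bn().eval()
    with torch.no_grad():
        logits = model(torch.randn(2, 3, 224, 224))
    assert logits.shape == (2, 1000)   # (54x54x96) -> ... -> (6x6x256) -> fc layers
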
diff --git a/models/imagenet/preresnet_imagenet.py b/models/imagenet/preresnet_imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..412ee99972b70ed1621ead4c97c5f4142eeacafe
--- /dev/null
+++ b/models/imagenet/preresnet_imagenet.py
@@ -0,0 +1,231 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Pre-Activation ResNet for ImageNet
+
+Pre-Activation ResNet for ImageNet, based on "Identity Mappings in Deep Residual Networks".
+This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate changes for pre-activation.
+
+@article{
+  He2016,
+  author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
+  title = {Identity Mappings in Deep Residual Networks},
+  journal = {arXiv preprint arXiv:1603.05027},
+  year = {2016}
+}
+"""
+
+import torch.nn as nn
+import math
+
+
+__all__ = ['PreactResNet', 'preact_resnet18', 'preact_resnet34', 'preact_resnet50', 'preact_resnet101',
+           'preact_resnet152']
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class PreactBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, preactivate=True):
+        super(PreactBasicBlock, self).__init__()
+        self.pre_bn = self.pre_relu = None
+        if preactivate:
+            self.pre_bn = nn.BatchNorm2d(inplanes)
+            self.pre_relu = nn.ReLU(inplace=True)
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1_2 = nn.BatchNorm2d(planes)
+        self.relu1_2 = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.preactivate = preactivate
+
+    def forward(self, x):
+        if self.preactivate:
+            preact = self.pre_bn(x)
+            preact = self.pre_relu(preact)
+        else:
+            preact = x
+
+        out = self.conv1(preact)
+        out = self.bn1_2(out)
+        out = self.relu1_2(out)
+        out = self.conv2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(preact)
+        else:
+            residual = x
+
+        out += residual
+
+        return out
+
+
+class PreactBottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, preactivate=True):
+        super(PreactBottleneck, self).__init__()
+        self.pre_bn = self.pre_relu = None
+        if preactivate:
+            self.pre_bn = nn.BatchNorm2d(inplanes)
+            self.pre_relu = nn.ReLU(inplace=True)
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1_2 = nn.BatchNorm2d(planes)
+        self.relu1_2 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=1, bias=False)
+        self.bn2_3 = nn.BatchNorm2d(planes)
+        self.relu2_3 = nn.ReLU(inplace=True)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.downsample = downsample
+        self.stride = stride
+        self.preactivate = preactivate
+
+    def forward(self, x):
+        if self.preactivate:
+            preact = self.pre_bn(x)
+            preact = self.pre_relu(preact)
+        else:
+            preact = x
+
+        out = self.conv1(preact)
+        out = self.bn1_2(out)
+        out = self.relu1_2(out)
+
+        out = self.conv2(out)
+        out = self.bn2_3(out)
+        out = self.relu2_3(out)
+
+        out = self.conv3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(preact)
+        else:
+            residual = x
+
+        out += residual
+
+        return out
+
+
+class PreactResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000):
+        self.inplanes = 64
+        super(PreactResNet, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.final_bn = nn.BatchNorm2d(512 * block.expansion)
+        self.final_relu = nn.ReLU(inplace=True)
+        self.avgpool = nn.AvgPool2d(7, stride=1)
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+            )
+
+        # The first residual block of the first residual layer is not pre-activated,
+        # because conv1 -> bn1 -> relu1 -> maxpool already provide the pre-activation
+        preactivate_first = stride != 1
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, preactivate_first))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu1(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.final_bn(x)
+        x = self.final_relu(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def preact_resnet18(**kwargs):
+    """Constructs a ResNet-18 model.
+    """
+    model = PreactResNet(PreactBasicBlock, [2, 2, 2, 2], **kwargs)
+    return model
+
+
+def preact_resnet34(**kwargs):
+    """Constructs a ResNet-34 model.
+    """
+    model = PreactResNet(PreactBasicBlock, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+def preact_resnet50(**kwargs):
+    """Constructs a ResNet-50 model.
+    """
+    model = PreactResNet(PreactBottleneck, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+def preact_resnet101(**kwargs):
+    """Constructs a ResNet-101 model.
+    """
+    model = PreactResNet(PreactBottleneck, [3, 4, 23, 3], **kwargs)
+    return model
+
+
+def preact_resnet152(**kwargs):
+    """Constructs a ResNet-152 model.
+    """
+    model = PreactResNet(PreactBottleneck, [3, 8, 36, 3], **kwargs)
+    return model
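
An illustrative check of the pre-activation placement chosen in _make_layer: the first block of layer1 skips its own BN/ReLU (conv1 -> bn1 -> relu1 -> maxpool already pre-activate it), while the stride-2 stages do pre-activate:

    import torch
    from models.imagenet.preresnet_imagenet import preact_resnet18

    model = preact_resnet18().eval()
    assert model.layer1[0].pre_bn is None        # preactivate=False for this block
    assert model.layer2[0].pre_bn is not None    # stride-2 stages pre-activate

    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    assert logits.shape == (1, 1000)
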
diff --git a/tests/test_learning_rate.py b/tests/test_learning_rate.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e63121c0122e246194738e14aa34eafa5ac05af
--- /dev/null
+++ b/tests/test_learning_rate.py
@@ -0,0 +1,50 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+module_path = os.path.abspath(os.path.join('..'))
+if module_path not in sys.path:
+    sys.path.append(module_path)
+
+import torch
+from torch.optim import Optimizer
+from distiller.learning_rate import MultiStepMultiGammaLR
+
+
+def test_multi_step_multi_gamma_lr():
+    dummy_tensor = torch.zeros(3, 3, 3, requires_grad=True)
+    dummy_optimizer = Optimizer([dummy_tensor], {'lr': 0.1})
+    lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1, 0.2])
+    expected_gammas = [1, 1 * 0.1, 1 * 0.1 * 0.1, 1 * 0.1 * 0.1 * 0.2]
+    expected_lrs = [0.1 * gamma for gamma in expected_gammas]
+    assert lr_sched.multiplicative_gammas == expected_gammas
+    lr_sched.step(0)
+    assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[0]
+    lr_sched.step(15)
+    assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[0]
+    lr_sched.step(30)
+    assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[1]
+    lr_sched.step(33)
+    assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[1]
+    lr_sched.step(60)
+    assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[2]
+    lr_sched.step(79)
+    assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[2]
+    lr_sched.step(80)
+    assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[3]
+    lr_sched.step(100)
+    assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[3]
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 50ce5753da84840bf1a8a89859ae67f3d950eec7..5bcb82da417a4035d222044a6fc2c8df03fcf393 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -80,10 +80,10 @@ def test_connectivity():
     preds = g.predecessors(op, 2)
     assert preds == ['layer1.0.bn2', 'relu']
 
-    op = g.find_op('layer1.0.relu1')
+    op = g.find_op('layer1.0.relu2')
     assert op is not None
     succs = g.successors(op, 4)
-    assert succs == ['layer1.1.bn1', 'layer1.1.relu1']
+    assert succs == ['layer1.1.bn1', 'layer1.1.relu2']
 
     preds = g.predecessors(g.find_op('bn1'), 10)
     assert preds == []