diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 0a82c11628e30f0348bb33a26c1e14dd48406301..46c3b1314f746f62c926253e398466f3e031467e 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -79,190 +79,21 @@ import apputils
 from distiller.data_loggers import *
 import distiller.quantization as quantization
 from models import ALL_MODEL_NAMES, create_model
+import parser

 # Logger handle
 msglogger = None


-def float_range(val_str):
-    val = float(val_str)
-    if val < 0 or val >= 1:
-        raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
-    return val
-
-
-parser = argparse.ArgumentParser(description='Distiller image classification model compression')
-parser.add_argument('data', metavar='DIR', help='path to dataset')
-parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
-                    choices=ALL_MODEL_NAMES,
-                    help='model architecture: ' +
-                    ' | '.join(ALL_MODEL_NAMES) +
-                    ' (default: resnet18)')
-parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
-                    help='number of data loading workers (default: 4)')
-parser.add_argument('--epochs', default=90, type=int, metavar='N',
-                    help='number of total epochs to run')
-parser.add_argument('-b', '--batch-size', default=256, type=int,
-                    metavar='N', help='mini-batch size (default: 256)')
-parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                    metavar='LR', help='initial learning rate')
-parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                    help='momentum')
-parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
-                    metavar='W', help='weight decay (default: 1e-4)')
-parser.add_argument('--print-freq', '-p', default=10, type=int,
-                    metavar='N', help='print frequency (default: 10)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                    help='path to latest checkpoint (default: none)')
-parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
-                    help='evaluate model on validation set')
-parser.add_argument('--pretrained', dest='pretrained', action='store_true',
-                    help='use pre-trained model')
-parser.add_argument('--act-stats', dest='activation_stats', choices=["train", "valid", "test"], default=None,
-                    help='collect activation statistics (WARNING: this slows down training)')
-parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
-                    help='print masks sparsity table at end of each epoch')
-parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
-                    help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
-SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
-parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
-                    help='print a summary of the model, and exit - options: ' +
-                    ' | '.join(SUMMARY_CHOICES))
-parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
-                    help='configuration file for pruning the model (default is to use hard-coded schedule)')
-parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'],
-                    help='test the sensitivity of layers to pruning')
-parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
-                    help='an optional parameter for sensitivity testing providing the range of sparsities to test.\n'
-                    'This is equivalent to creating sensitivities = np.arange(start, stop, step)')
-parser.add_argument('--extras', default=None, type=str,
-                    help='file with extra configuration information')
-parser.add_argument('--deterministic', '--det', action='store_true',
-                    help='Ensure deterministic execution for re-producible results.')
-parser.add_argument('--gpus', metavar='DEV_ID', default=None,
-                    help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
-parser.add_argument('--cpu', action='store_true', default=False,
-                    help='Use CPU only. \n'
-                    'Flag not set => uses GPUs according to the --gpus flag value.'
-                    'Flag set => overrides the --gpus flag')
-parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
-parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
-parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
-                    help='Portion of training dataset to set aside for validation')
-parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
-parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
-parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
-                    help='Display the confusion matrix')
-parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
-                    help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
-parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
-                    help='List of EarlyExit thresholds (e.g. --earlyexit_thresholds 1.2 0.9)')
-parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int,
-                    help='number of best scores to track and report (default: 1)')
-parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
-                    help='Load a model without DataParallel wrapping it')
-
-str_to_quant_mode_map = {'sym': quantization.LinearQuantMode.SYMMETRIC,
-                         'asym_s': quantization.LinearQuantMode.ASYMMETRIC_SIGNED,
-                         'asym_u': quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED}
-
-
-def linear_quant_mode_str(val_str):
-    try:
-        return str_to_quant_mode_map[val_str]
-    except KeyError:
-        raise argparse.ArgumentError('Must be one of {0} (received {1})'.format(list(str_to_quant_mode_map.keys()),
-                                                                                val_str))
-
-
-quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time'
-                                        '("post-training quantization)')
-quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
-                         help='Apply linear quantization to model before evaluation. Applicable only if'
-                         '--evaluate is also set')
-quant_group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
-                         help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
-quant_group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
-                         help='Number of bits for quantization of activations')
-quant_group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS',
-                         help='Number of bits for quantization of weights')
-quant_group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS',
-                         help='Number of bits for quantization of the accumulator')
-quant_group.add_argument('--qe-clip-acts', '--qeca', action='store_true',
-                         help='Enable clipping of activations using min/max values averaging over batch')
-quant_group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[],
-                         help='List of layer names for which not to clip activations. Applicable only if '
-                         '--qe-clip-acts is also set')
-quant_group.add_argument('--qe-per-channel', '--qepc', action='store_true',
-                         help='Enable per-channel quantization of weights (per output channel)')
-
-distiller.knowledge_distillation.add_distillation_args(parser, ALL_MODEL_NAMES, True)
-
-
-def check_pytorch_version():
-    if torch.__version__ < '0.4.0':
-        print("\nNOTICE:")
-        print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to "
-              "PyTorch API changes which are not backward-compatible.\n"
-              "Please install PyTorch 0.4.0 or its derivative.\n"
-              "If you are using a virtual environment, do not forget to update it:\n"
-              "  1. Deactivate the old environment\n"
-              "  2. Install the new environment\n"
-              "  3. Activate the new environment")
-        exit(1)
-
-
-def create_activation_stats_collectors(model, collection_phase):
-    """Create objects that collect activation statistics.
-
-    This is a utility function that creates two collectors:
-    1. Fine-grade sparsity levels of the activations
-    2. L1-magnitude of each of the activation channels
-
-    Args:
-        model - the model on which we want to collect statistics
-        phase - the statistics collection phase which is either "train" (for training),
-                or "valid" (for validation)
-
-    WARNING! Enabling activation statsitics collection will significantly slow down training!
- """ - class missingdict(dict): - """This is a little trick to prevent KeyError""" - def __missing__(self, key): - return None # note, does *not* set self[key] - we don't want defaultdict's behavior - - distiller.utils.assign_layer_fq_names(model) - - activations_collectors = {"train": missingdict(), "valid": missingdict(), "test": missingdict()} - if collection_phase is None: - return activations_collectors - collectors = missingdict({ - "sparsity": SummaryActivationStatsCollector(model, "sparsity", - lambda t: 100 * distiller.utils.sparsity(t)), - "l1_channels": SummaryActivationStatsCollector(model, "l1_channels", - distiller.utils.activation_channels_l1), - "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels", - distiller.utils.activation_channels_apoz), - "records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d]) - }) - activations_collectors[collection_phase] = collectors - return activations_collectors - - -def save_collectors_data(collectors, directory): - """Utility function that saves all activation statistics to Excel workbooks - """ - for name, collector in collectors.items(): - workbook = os.path.join(directory, name) - msglogger.info("Generating {}".format(workbook)) - collector.to_xlsx(workbook) - - def main(): global msglogger - check_pytorch_version() - args = parser.parse_args() + + # Parse arguments + prsr = parser.getParser() + distiller.knowledge_distillation.add_distillation_args(prsr, ALL_MODEL_NAMES, True) + args = prsr.parse_args() + if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir) @@ -369,7 +200,7 @@ def main(): msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d', len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler)) - activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats) + activations_collectors = create_activation_stats_collectors(model, *args.activation_stats) if args.sensitivity is not None: sensitivities = np.arange(args.sensitivity_range[0], args.sensitivity_range[1], args.sensitivity_range[2]) @@ -784,7 +615,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsitie loggers = [loggers] test_fnc = partial(test, test_loader=data_loader, criterion=criterion, loggers=loggers, args=args, - activations_collectors=create_activation_stats_collectors(model, None)) + activations_collectors=create_activation_stats_collectors(model)) which_params = [param_name for param_name, _ in model.named_parameters()] sensitivity = distiller.perform_sensitivity_analysis(model, net_params=which_params, @@ -823,8 +654,65 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args): ADC.do_adc(model, args.dataset, args.arch, optimizer_data, validate_fn, save_checkpoint_fn, train_fn) +def create_activation_stats_collectors(model, *phases): + """Create objects that collect activation statistics. + + This is a utility function that creates two collectors: + 1. Fine-grade sparsity levels of the activations + 2. L1-magnitude of each of the activation channels + + Args: + model - the model on which we want to collect statistics + phases - the statistics collection phases: train, valid, and/or test + + WARNING! Enabling activation statsitics collection will significantly slow down training! 
+ """ + class missingdict(dict): + """This is a little trick to prevent KeyError""" + def __missing__(self, key): + return None # note, does *not* set self[key] - we don't want defaultdict's behavior + + distiller.utils.assign_layer_fq_names(model) + + genCollectors = lambda: missingdict({ + "sparsity": SummaryActivationStatsCollector(model, "sparsity", + lambda t: 100 * distiller.utils.sparsity(t)), + "l1_channels": SummaryActivationStatsCollector(model, "l1_channels", + distiller.utils.activation_channels_l1), + "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels", + distiller.utils.activation_channels_apoz), + "records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d]) + }) + + return {k: (genCollectors() if k in phases else missingdict()) + for k in ('train', 'valid', 'test')} + + +def save_collectors_data(collectors, directory): + """Utility function that saves all activation statistics to Excel workbooks + """ + for name, collector in collectors.items(): + workbook = os.path.join(directory, name) + msglogger.info("Generating {}".format(workbook)) + collector.to_xlsx(workbook) + + +def check_pytorch_version(): + if torch.__version__ < '0.4.0': + print("\nNOTICE:") + print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to " + "PyTorch API changes which are not backward-compatible.\n" + "Please install PyTorch 0.4.0 or its derivative.\n" + "If you are using a virtual environment, do not forget to update it:\n" + " 1. Deactivate the old environment\n" + " 2. Install the new environment\n" + " 3. Activate the new environment") + exit(1) + + if __name__ == '__main__': try: + check_pytorch_version() main() except KeyboardInterrupt: print("\n-- KeyboardInterrupt --") diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f62174423bca586452cae1cb332ddde35786e57e --- /dev/null +++ b/examples/classifier_compression/parser.py @@ -0,0 +1,140 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import argparse
+
+import distiller
+import models
+
+
+SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
+
+def getParser():
+    parser = argparse.ArgumentParser(description='Distiller image classification model compression')
+    parser.add_argument('data', metavar='DIR', help='path to dataset')
+    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', type=lambda s: s.lower(),
+                        choices=models.ALL_MODEL_NAMES,
+                        help='model architecture: ' +
+                        ' | '.join(models.ALL_MODEL_NAMES) +
+                        ' (default: resnet18)')
+    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                        help='number of data loading workers (default: 4)')
+    parser.add_argument('--epochs', default=90, type=int, metavar='N',
+                        help='number of total epochs to run')
+    parser.add_argument('-b', '--batch-size', default=256, type=int,
+                        metavar='N', help='mini-batch size (default: 256)')
+    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                        metavar='LR', help='initial learning rate')
+    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                        help='momentum')
+    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                        metavar='W', help='weight decay (default: 1e-4)')
+    parser.add_argument('--print-freq', '-p', default=10, type=int,
+                        metavar='N', help='print frequency (default: 10)')
+    parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                        help='path to latest checkpoint (default: none)')
+    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                        help='evaluate model on validation set')
+    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+                        help='use pre-trained model')
+    parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(),
+                        # choices=["train", "valid", "test"]
+                        help='collect activation statistics on phases: train, valid, and/or test'
+                        ' (WARNING: this slows down training)')
+    parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
+                        help='print masks sparsity table at end of each epoch')
+    parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
+                        help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
+    parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES,
+                        help='print a summary of the model, and exit - options: ' +
+                        ' | '.join(SUMMARY_CHOICES))
+    parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
+                        help='configuration file for pruning the model (default is to use hard-coded schedule)')
+    parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'], type=lambda s: s.lower(),
+                        help='test the sensitivity of layers to pruning')
+    parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
+                        help='an optional parameter for sensitivity testing providing the range of sparsities to test.\n'
+                        'This is equivalent to creating sensitivities = np.arange(start, stop, step)')
+    parser.add_argument('--extras', default=None, type=str,
+                        help='file with extra configuration information')
+    parser.add_argument('--deterministic', '--det', action='store_true',
+                        help='Ensure deterministic execution for reproducible results.')
+    parser.add_argument('--gpus', metavar='DEV_ID', default=None,
+                        help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
+    parser.add_argument('--cpu', action='store_true', default=False,
+                        help='Use CPU only. \n'
+                        'Flag not set => uses GPUs according to the --gpus flag value. '
+                        'Flag set => overrides the --gpus flag')
+    parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
+    parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
+    parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
+                        help='Portion of training dataset to set aside for validation')
+    parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
+    parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
+    parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
+                        help='Display the confusion matrix')
+    parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
+                        help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
+    parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
+                        help='List of EarlyExit thresholds (e.g. --earlyexit_thresholds 1.2 0.9)')
+    parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int,
+                        help='number of best scores to track and report (default: 1)')
+    parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
+                        help='Load a model without DataParallel wrapping it')
+
+    str_to_quant_mode_map = {
+        'sym': distiller.quantization.LinearQuantMode.SYMMETRIC,
+        'asym_s': distiller.quantization.LinearQuantMode.ASYMMETRIC_SIGNED,
+        'asym_u': distiller.quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED,
+    }
+
+    def linear_quant_mode_str(val_str):
+        try:
+            return str_to_quant_mode_map[val_str]
+        except KeyError:
+            raise argparse.ArgumentTypeError(
+                'Must be one of {0} (received {1})'.format(
+                    list(str_to_quant_mode_map), val_str))
+
+    quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time'
+                                            ' ("post-training quantization")')
+    quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
+                             help='Apply linear quantization to model before evaluation. Applicable only if '
+                             '--evaluate is also set')
+    quant_group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
+                             help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
+    quant_group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
+                             help='Number of bits for quantization of activations')
+    quant_group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS',
+                             help='Number of bits for quantization of weights')
+    quant_group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS',
+                             help='Number of bits for quantization of the accumulator')
+    quant_group.add_argument('--qe-clip-acts', '--qeca', action='store_true',
+                             help='Enable clipping of activations using min/max values averaging over batch')
+    quant_group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[],
+                             help='List of layer names for which not to clip activations. Applicable only if '
+                             '--qe-clip-acts is also set')
+    quant_group.add_argument('--qe-per-channel', '--qepc', action='store_true',
+                             help='Enable per-channel quantization of weights (per output channel)')
+
+    return parser
+
+
+def float_range(val_str):
+    val = float(val_str)
+    if val < 0 or val >= 1:
+        raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
+    return val
diff --git a/models/__init__.py b/models/__init__.py
index 93545ea56c0de532d441eb4bbb86efcbbe160e58..e7c19de2c87daa921d43a2dded9de698589a2b7b 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -39,7 +39,7 @@ CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
                              if name.islower() and not name.startswith("__")
                              and callable(cifar10_models.__dict__[name]))

-ALL_MODEL_NAMES = sorted(set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES))
+ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))

 def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
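
Usage note: a minimal sketch of how the refactored entry point wires things together after this change. The dataset path and argument values below are hypothetical; getParser and add_distillation_args are the names used in this diff.

import distiller
import parser  # the new examples/classifier_compression/parser.py, not a stdlib module
from models import ALL_MODEL_NAMES

prsr = parser.getParser()
distiller.knowledge_distillation.add_distillation_args(prsr, ALL_MODEL_NAMES, True)

# --activation-stats now takes one or more phases (the old --act-stats took a
# single phase), and --arch is lower-cased by its type= converter before the
# choices check, matching the lower-cased ALL_MODEL_NAMES.
args = prsr.parse_args(['/path/to/dataset', '--arch', 'ResNet18',
                        '--activation-stats', 'train', 'valid'])
assert args.arch == 'resnet18'
assert args.activation_stats == ['train', 'valid']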
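
The relocated create_activation_stats_collectors accepts any number of phases instead of a single collection_phase. A sketch of the returned dict's semantics (model stands for any model returned by create_model):

collectors = create_activation_stats_collectors(model, 'train', 'valid')

# All three phase keys are always present.
assert set(collectors.keys()) == {'train', 'valid', 'test'}

# Requested phases map to a missingdict of four collectors; unrequested phases
# map to an empty missingdict, so lookups return None instead of raising KeyError.
assert collectors['test']['sparsity'] is None
train_sparsity = collectors['train']['sparsity']  # a SummaryActivationStatsCollector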
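
float_range, moved verbatim into parser.py, still validates the --validation-size fraction:

float_range('0.1')  # returns 0.1
float_range('1.0')  # raises argparse.ArgumentTypeError: Must be >= 0 and < 1 (received 1.0)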
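
Because --qe-mode uses linear_quant_mode_str as its type= converter, args.qe_mode arrives as a LinearQuantMode enum member rather than a raw string (sketch, reusing prsr from above):

args = prsr.parse_args(['/path/to/dataset', '--evaluate', '--quantize-eval',
                        '--qe-mode', 'asym_u'])
assert args.qe_mode is distiller.quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED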