diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 0a82c11628e30f0348bb33a26c1e14dd48406301..46c3b1314f746f62c926253e398466f3e031467e 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -79,190 +79,21 @@ import apputils
 from distiller.data_loggers import *
 import distiller.quantization as quantization
 from models import ALL_MODEL_NAMES, create_model
+import parser

 # Logger handle
 msglogger = None


-def float_range(val_str):
-    val = float(val_str)
-    if val < 0 or val >= 1:
-        raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
-    return val
-
-
-parser = argparse.ArgumentParser(description='Distiller image classification model compression')
-parser.add_argument('data', metavar='DIR', help='path to dataset')
-parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
-                    choices=ALL_MODEL_NAMES,
-                    help='model architecture: ' +
-                    ' | '.join(ALL_MODEL_NAMES) +
-                    ' (default: resnet18)')
-parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
-                    help='number of data loading workers (default: 4)')
-parser.add_argument('--epochs', default=90, type=int, metavar='N',
-                    help='number of total epochs to run')
-parser.add_argument('-b', '--batch-size', default=256, type=int,
-                    metavar='N', help='mini-batch size (default: 256)')
-parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                    metavar='LR', help='initial learning rate')
-parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                    help='momentum')
-parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
-                    metavar='W', help='weight decay (default: 1e-4)')
-parser.add_argument('--print-freq', '-p', default=10, type=int,
-                    metavar='N', help='print frequency (default: 10)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                    help='path to latest checkpoint (default: none)')
-parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
-                    help='evaluate model on validation set')
-parser.add_argument('--pretrained', dest='pretrained', action='store_true',
-                    help='use pre-trained model')
-parser.add_argument('--act-stats', dest='activation_stats', choices=["train", "valid", "test"], default=None,
-                    help='collect activation statistics (WARNING: this slows down training)')
-parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
-                    help='print masks sparsity table at end of each epoch')
-parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
-                    help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
-SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
-parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
-                    help='print a summary of the model, and exit - options: ' +
-                    ' | '.join(SUMMARY_CHOICES))
-parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
-                    help='configuration file for pruning the model (default is to use hard-coded schedule)')
-parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'],
-                    help='test the sensitivity of layers to pruning')
-parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
-                    help='an optional parameter for sensitivity testing providing the range of sparsities to test.\n'
-                    'This is equivalent to creating sensitivities = np.arange(start, stop, step)')
-parser.add_argument('--extras', default=None, type=str,
-                    help='file with extra configuration information')
-parser.add_argument('--deterministic', '--det', action='store_true',
-                    help='Ensure deterministic execution for re-producible results.')
-parser.add_argument('--gpus', metavar='DEV_ID', default=None,
-                    help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
-parser.add_argument('--cpu', action='store_true', default=False,
-                    help='Use CPU only. \n'
-                    'Flag not set => uses GPUs according to the --gpus flag value.'
-                    'Flag set => overrides the --gpus flag')
-parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
-parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
-parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
-                    help='Portion of training dataset to set aside for validation')
-parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
-parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
-parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
-                    help='Display the confusion matrix')
-parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
-                    help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
-parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
-                    help='List of EarlyExit thresholds (e.g. --earlyexit_thresholds 1.2 0.9)')
-parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int,
-                    help='number of best scores to track and report (default: 1)')
-parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
-                    help='Load a model without DataParallel wrapping it')
-
-str_to_quant_mode_map = {'sym': quantization.LinearQuantMode.SYMMETRIC,
-                         'asym_s': quantization.LinearQuantMode.ASYMMETRIC_SIGNED,
-                         'asym_u': quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED}
-
-
-def linear_quant_mode_str(val_str):
-    try:
-        return str_to_quant_mode_map[val_str]
-    except KeyError:
-        raise argparse.ArgumentError('Must be one of {0} (received {1})'.format(list(str_to_quant_mode_map.keys()),
-                                                                                val_str))
-
-
-quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time'
-                                        '("post-training quantization)')
-quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
-                         help='Apply linear quantization to model before evaluation. Applicable only if'
-                         '--evaluate is also set')
-quant_group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
-                         help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
-quant_group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
-                         help='Number of bits for quantization of activations')
-quant_group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS',
-                         help='Number of bits for quantization of weights')
-quant_group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS',
-                         help='Number of bits for quantization of the accumulator')
-quant_group.add_argument('--qe-clip-acts', '--qeca', action='store_true',
-                         help='Enable clipping of activations using min/max values averaging over batch')
-quant_group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[],
-                         help='List of layer names for which not to clip activations. Applicable only if '
-                         '--qe-clip-acts is also set')
-quant_group.add_argument('--qe-per-channel', '--qepc', action='store_true',
-                         help='Enable per-channel quantization of weights (per output channel)')
-
-distiller.knowledge_distillation.add_distillation_args(parser, ALL_MODEL_NAMES, True)
-
-
-def check_pytorch_version():
-    if torch.__version__ < '0.4.0':
-        print("\nNOTICE:")
-        print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to "
-              "PyTorch API changes which are not backward-compatible.\n"
-              "Please install PyTorch 0.4.0 or its derivative.\n"
-              "If you are using a virtual environment, do not forget to update it:\n"
-              "  1. Deactivate the old environment\n"
-              "  2. Install the new environment\n"
-              "  3. Activate the new environment")
-        exit(1)
-
-
-def create_activation_stats_collectors(model, collection_phase):
-    """Create objects that collect activation statistics.
-
-    This is a utility function that creates two collectors:
-    1. Fine-grade sparsity levels of the activations
-    2. L1-magnitude of each of the activation channels
-
-    Args:
-        model - the model on which we want to collect statistics
-        phase - the statistics collection phase which is either "train" (for training),
-                or "valid" (for validation)
-
-    WARNING! Enabling activation statsitics collection will significantly slow down training!
- """ - class missingdict(dict): - """This is a little trick to prevent KeyError""" - def __missing__(self, key): - return None # note, does *not* set self[key] - we don't want defaultdict's behavior - - distiller.utils.assign_layer_fq_names(model) - - activations_collectors = {"train": missingdict(), "valid": missingdict(), "test": missingdict()} - if collection_phase is None: - return activations_collectors - collectors = missingdict({ - "sparsity": SummaryActivationStatsCollector(model, "sparsity", - lambda t: 100 * distiller.utils.sparsity(t)), - "l1_channels": SummaryActivationStatsCollector(model, "l1_channels", - distiller.utils.activation_channels_l1), - "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels", - distiller.utils.activation_channels_apoz), - "records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d]) - }) - activations_collectors[collection_phase] = collectors - return activations_collectors - - -def save_collectors_data(collectors, directory): - """Utility function that saves all activation statistics to Excel workbooks - """ - for name, collector in collectors.items(): - workbook = os.path.join(directory, name) - msglogger.info("Generating {}".format(workbook)) - collector.to_xlsx(workbook) - - def main(): global msglogger - check_pytorch_version() - args = parser.parse_args() + + # Parse arguments + prsr = parser.getParser() + distiller.knowledge_distillation.add_distillation_args(prsr, ALL_MODEL_NAMES, True) + args = prsr.parse_args() + if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir) @@ -369,7 +200,7 @@ def main(): msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d', len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler)) - activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats) + activations_collectors = create_activation_stats_collectors(model, *args.activation_stats) if args.sensitivity is not None: sensitivities = np.arange(args.sensitivity_range[0], args.sensitivity_range[1], args.sensitivity_range[2]) @@ -784,7 +615,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsitie loggers = [loggers] test_fnc = partial(test, test_loader=data_loader, criterion=criterion, loggers=loggers, args=args, - activations_collectors=create_activation_stats_collectors(model, None)) + activations_collectors=create_activation_stats_collectors(model)) which_params = [param_name for param_name, _ in model.named_parameters()] sensitivity = distiller.perform_sensitivity_analysis(model, net_params=which_params, @@ -823,8 +654,65 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args): ADC.do_adc(model, args.dataset, args.arch, optimizer_data, validate_fn, save_checkpoint_fn, train_fn) +def create_activation_stats_collectors(model, *phases): + """Create objects that collect activation statistics. + + This is a utility function that creates two collectors: + 1. Fine-grade sparsity levels of the activations + 2. L1-magnitude of each of the activation channels + + Args: + model - the model on which we want to collect statistics + phases - the statistics collection phases: train, valid, and/or test + + WARNING! Enabling activation statsitics collection will significantly slow down training! 
+ """ + class missingdict(dict): + """This is a little trick to prevent KeyError""" + def __missing__(self, key): + return None # note, does *not* set self[key] - we don't want defaultdict's behavior + + distiller.utils.assign_layer_fq_names(model) + + genCollectors = lambda: missingdict({ + "sparsity": SummaryActivationStatsCollector(model, "sparsity", + lambda t: 100 * distiller.utils.sparsity(t)), + "l1_channels": SummaryActivationStatsCollector(model, "l1_channels", + distiller.utils.activation_channels_l1), + "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels", + distiller.utils.activation_channels_apoz), + "records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d]) + }) + + return {k: (genCollectors() if k in phases else missingdict()) + for k in ('train', 'valid', 'test')} + + +def save_collectors_data(collectors, directory): + """Utility function that saves all activation statistics to Excel workbooks + """ + for name, collector in collectors.items(): + workbook = os.path.join(directory, name) + msglogger.info("Generating {}".format(workbook)) + collector.to_xlsx(workbook) + + +def check_pytorch_version(): + if torch.__version__ < '0.4.0': + print("\nNOTICE:") + print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to " + "PyTorch API changes which are not backward-compatible.\n" + "Please install PyTorch 0.4.0 or its derivative.\n" + "If you are using a virtual environment, do not forget to update it:\n" + " 1. Deactivate the old environment\n" + " 2. Install the new environment\n" + " 3. Activate the new environment") + exit(1) + + if __name__ == '__main__': try: + check_pytorch_version() main() except KeyboardInterrupt: print("\n-- KeyboardInterrupt --") diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f62174423bca586452cae1cb332ddde35786e57e --- /dev/null +++ b/examples/classifier_compression/parser.py @@ -0,0 +1,140 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import argparse
+
+import distiller
+import models
+
+
+SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
+
+def getParser():
+    parser = argparse.ArgumentParser(description='Distiller image classification model compression')
+    parser.add_argument('data', metavar='DIR', help='path to dataset')
+    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', type=lambda s: s.lower(),
+                        choices=models.ALL_MODEL_NAMES,
+                        help='model architecture: ' +
+                        ' | '.join(models.ALL_MODEL_NAMES) +
+                        ' (default: resnet18)')
+    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                        help='number of data loading workers (default: 4)')
+    parser.add_argument('--epochs', default=90, type=int, metavar='N',
+                        help='number of total epochs to run')
+    parser.add_argument('-b', '--batch-size', default=256, type=int,
+                        metavar='N', help='mini-batch size (default: 256)')
+    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                        metavar='LR', help='initial learning rate')
+    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                        help='momentum')
+    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                        metavar='W', help='weight decay (default: 1e-4)')
+    parser.add_argument('--print-freq', '-p', default=10, type=int,
+                        metavar='N', help='print frequency (default: 10)')
+    parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                        help='path to latest checkpoint (default: none)')
+    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                        help='evaluate model on validation set')
+    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+                        help='use pre-trained model')
+    parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(),
+                        # choices=["train", "valid", "test"]
+                        help='collect activation statistics on phases: train, valid, and/or test'
+                        ' (WARNING: this slows down training)')
+    parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
+                        help='print masks sparsity table at end of each epoch')
+    parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
+                        help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
+    parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES,
+                        help='print a summary of the model, and exit - options: ' +
+                        ' | '.join(SUMMARY_CHOICES))
+    parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
+                        help='configuration file for pruning the model (default is to use hard-coded schedule)')
+    parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'], type=lambda s: s.lower(),
+                        help='test the sensitivity of layers to pruning')
+    parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
+                        help='an optional parameter for sensitivity testing providing the range of sparsities to test.\n'
+                        'This is equivalent to creating sensitivities = np.arange(start, stop, step)')
+    parser.add_argument('--extras', default=None, type=str,
+                        help='file with extra configuration information')
+    parser.add_argument('--deterministic', '--det', action='store_true',
+                        help='Ensure deterministic execution for reproducible results.')
+    parser.add_argument('--gpus', metavar='DEV_ID', default=None,
+                        help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
+    parser.add_argument('--cpu', action='store_true', default=False,
+                        help='Use CPU only. \n'
+                        'Flag not set => uses GPUs according to the --gpus flag value. '
+                        'Flag set => overrides the --gpus flag')
+    parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
+    parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
+    parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
+                        help='Portion of training dataset to set aside for validation')
+    parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
+    parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
+    parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
+                        help='Display the confusion matrix')
+    parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
+                        help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
+    parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
+                        help='List of EarlyExit thresholds (e.g. --earlyexit_thresholds 1.2 0.9)')
+    parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int,
+                        help='number of best scores to track and report (default: 1)')
+    parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
+                        help='Load a model without DataParallel wrapping it')
+
+    str_to_quant_mode_map = {
+        'sym': distiller.quantization.LinearQuantMode.SYMMETRIC,
+        'asym_s': distiller.quantization.LinearQuantMode.ASYMMETRIC_SIGNED,
+        'asym_u': distiller.quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED,
+    }
+
+    def linear_quant_mode_str(val_str):
+        try:
+            return str_to_quant_mode_map[val_str]
+        except KeyError:
+            raise argparse.ArgumentTypeError(
+                'Must be one of {0} (received {1})'.format(
+                    list(str_to_quant_mode_map), val_str))
+
+    quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time'
+                                            ' ("post-training quantization")')
+    quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
+                             help='Apply linear quantization to model before evaluation. Applicable only if '
+                             '--evaluate is also set')
+    quant_group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
+                             help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
+    quant_group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
+                             help='Number of bits for quantization of activations')
+    quant_group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS',
+                             help='Number of bits for quantization of weights')
+    quant_group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS',
+                             help='Number of bits for quantization of the accumulator')
+    quant_group.add_argument('--qe-clip-acts', '--qeca', action='store_true',
+                             help='Enable clipping of activations using min/max values averaging over batch')
+    quant_group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[],
+                             help='List of layer names for which not to clip activations. Applicable only if '
+                             '--qe-clip-acts is also set')
+    quant_group.add_argument('--qe-per-channel', '--qepc', action='store_true',
+                             help='Enable per-channel quantization of weights (per output channel)')
+
+    return parser
+
+
+def float_range(val_str):
+    val = float(val_str)
+    if val < 0 or val >= 1:
+        raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
+    return val
diff --git a/models/__init__.py b/models/__init__.py
index 93545ea56c0de532d441eb4bbb86efcbbe160e58..e7c19de2c87daa921d43a2dded9de698589a2b7b 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -39,7 +39,7 @@ CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
                              if name.islower() and not name.startswith("__")
                              and callable(cifar10_models.__dict__[name]))

-ALL_MODEL_NAMES = sorted(set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES))
+ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))

 def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
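
Usage note: a minimal sketch of how the refactored entry point wires things together after this change. The dataset path and argument values below are hypothetical; getParser and add_distillation_args are the names used in this diff.

import distiller
import parser  # the new examples/classifier_compression/parser.py, not a stdlib module
from models import ALL_MODEL_NAMES

prsr = parser.getParser()
distiller.knowledge_distillation.add_distillation_args(prsr, ALL_MODEL_NAMES, True)

# --activation-stats now takes one or more phases (the old --act-stats took a
# single phase), and --arch is lower-cased by its type= converter before the
# choices check, matching the lower-cased ALL_MODEL_NAMES.
args = prsr.parse_args(['/path/to/dataset', '--arch', 'ResNet18',
                        '--activation-stats', 'train', 'valid'])
assert args.arch == 'resnet18'
assert args.activation_stats == ['train', 'valid']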
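
The relocated create_activation_stats_collectors accepts any number of phases instead of a single collection_phase. A sketch of the returned dict's semantics (model stands for any model returned by create_model):

collectors = create_activation_stats_collectors(model, 'train', 'valid')

# All three phase keys are always present.
assert set(collectors.keys()) == {'train', 'valid', 'test'}

# Requested phases map to a missingdict of four collectors; unrequested phases
# map to an empty missingdict, so lookups return None instead of raising KeyError.
assert collectors['test']['sparsity'] is None
train_sparsity = collectors['train']['sparsity']  # a SummaryActivationStatsCollector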
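
float_range, moved verbatim into parser.py, still validates the --validation-size fraction:

float_range('0.1')  # returns 0.1
float_range('1.0')  # raises argparse.ArgumentTypeError: Must be >= 0 and < 1 (received 1.0)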
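
Because --qe-mode uses linear_quant_mode_str as its type= converter, args.qe_mode arrives as a LinearQuantMode enum member rather than a raw string (sketch, reusing prsr from above):

args = prsr.parse_args(['/path/to/dataset', '--evaluate', '--quantize-eval',
                        '--qe-mode', 'asym_u'])
assert args.qe_mode is distiller.quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED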