Commit cfbc3798 authored by Bar, committed by Neta Zmora

compress_classifier.py refactoring (#126)

* Support for multi-phase activation logging

Enable logging activations during both training and validation in
the same session.

* Refactoring: Move parser to its own file

* Parser is moved from compress_classifier into its own file.
* Torch version check is moved to precede main() call.
* Move main definition to the top of the file.
* Make parser choices case-insensitive (see the sketch below)
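
A minimal sketch of the case-insensitive pattern this commit adopts (argparse runs the type callable before validating against choices; the model names below are illustrative):

import argparse

# Lowercase the user's input before argparse checks it against `choices`;
# with lowercase choices, any input casing is accepted.
p = argparse.ArgumentParser()
p.add_argument('--arch', type=lambda s: s.lower(),
               choices=['resnet18', 'resnet20_cifar'], default='resnet18')
print(p.parse_args(['--arch', 'ResNet18']).arch)  # -> 'resnet18'
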
parent 4cc0e7d6
compress_classifier.py:
@@ -79,190 +79,21 @@ import apputils
from distiller.data_loggers import *
import distiller.quantization as quantization
from models import ALL_MODEL_NAMES, create_model
import parser
# Logger handle
msglogger = None
def float_range(val_str):
val = float(val_str)
if val < 0 or val >= 1:
raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
return val
parser = argparse.ArgumentParser(description='Distiller image classification model compression')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=ALL_MODEL_NAMES,
help='model architecture: ' +
' | '.join(ALL_MODEL_NAMES) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--act-stats', dest='activation_stats', choices=["train", "valid", "test"], default=None,
help='collect activation statistics (WARNING: this slows down training)')
parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
help='print masks sparsity table at end of each epoch')
parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
help='print a summary of the model, and exit - options: ' +
' | '.join(SUMMARY_CHOICES))
parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
help='configuration file for pruning the model (default is to use hard-coded schedule)')
parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'],
help='test the sensitivity of layers to pruning')
parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
help='an optional parameter for sensitivity testing providing the range of sparsities to test.\n'
'This is equivalent to creating sensitivities = np.arange(start, stop, step)')
parser.add_argument('--extras', default=None, type=str,
help='file with extra configuration information')
parser.add_argument('--deterministic', '--det', action='store_true',
help='Ensure deterministic execution for re-producible results.')
parser.add_argument('--gpus', metavar='DEV_ID', default=None,
help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
parser.add_argument('--cpu', action='store_true', default=False,
                    help='Use CPU only. \n'
                         'Flag not set => uses GPUs according to the --gpus flag value.\n'
                         'Flag set => overrides the --gpus flag')
parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
help='Portion of training dataset to set aside for validation')
parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
help='Display the confusion matrix')
parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
help='List of EarlyExit thresholds (e.g. --earlyexit_thresholds 1.2 0.9)')
parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int,
help='number of best scores to track and report (default: 1)')
parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
help='Load a model without DataParallel wrapping it')
str_to_quant_mode_map = {'sym': quantization.LinearQuantMode.SYMMETRIC,
'asym_s': quantization.LinearQuantMode.ASYMMETRIC_SIGNED,
'asym_u': quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED}
def linear_quant_mode_str(val_str):
    try:
        return str_to_quant_mode_map[val_str]
    except KeyError:
        raise argparse.ArgumentTypeError('Must be one of {0} (received {1})'.format(
            list(str_to_quant_mode_map.keys()), val_str))
quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time '
                                        '("post-training quantization")')
quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
                         help='Apply linear quantization to model before evaluation. Applicable only if '
                              '--evaluate is also set')
quant_group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
quant_group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
help='Number of bits for quantization of activations')
quant_group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS',
help='Number of bits for quantization of weights')
quant_group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS',
help='Number of bits for quantization of the accumulator')
quant_group.add_argument('--qe-clip-acts', '--qeca', action='store_true',
help='Enable clipping of activations using min/max values averaging over batch')
quant_group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[],
help='List of layer names for which not to clip activations. Applicable only if '
'--qe-clip-acts is also set')
quant_group.add_argument('--qe-per-channel', '--qepc', action='store_true',
help='Enable per-channel quantization of weights (per output channel)')
distiller.knowledge_distillation.add_distillation_args(parser, ALL_MODEL_NAMES, True)
def check_pytorch_version():
if torch.__version__ < '0.4.0':
print("\nNOTICE:")
print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to "
"PyTorch API changes which are not backward-compatible.\n"
"Please install PyTorch 0.4.0 or its derivative.\n"
"If you are using a virtual environment, do not forget to update it:\n"
" 1. Deactivate the old environment\n"
" 2. Install the new environment\n"
" 3. Activate the new environment")
exit(1)
def create_activation_stats_collectors(model, collection_phase):
    """Create objects that collect activation statistics.

    This is a utility function that creates four collectors:
    1. Fine-grained sparsity levels of the activations
    2. L1-magnitude of each of the activation channels
    3. Average percentage of zeros (APoZ) of each of the activation channels
    4. Raw activation records for Conv2d layers

    Args:
        model - the model on which we want to collect statistics
        collection_phase - the statistics collection phase: "train" (training),
            "valid" (validation), or "test" (testing)

    WARNING! Enabling activation statistics collection will significantly slow down training!
    """
class missingdict(dict):
"""This is a little trick to prevent KeyError"""
def __missing__(self, key):
return None # note, does *not* set self[key] - we don't want defaultdict's behavior
distiller.utils.assign_layer_fq_names(model)
activations_collectors = {"train": missingdict(), "valid": missingdict(), "test": missingdict()}
if collection_phase is None:
return activations_collectors
collectors = missingdict({
"sparsity": SummaryActivationStatsCollector(model, "sparsity",
lambda t: 100 * distiller.utils.sparsity(t)),
"l1_channels": SummaryActivationStatsCollector(model, "l1_channels",
distiller.utils.activation_channels_l1),
"apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels",
distiller.utils.activation_channels_apoz),
"records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])
})
activations_collectors[collection_phase] = collectors
return activations_collectors
def save_collectors_data(collectors, directory):
"""Utility function that saves all activation statistics to Excel workbooks
"""
for name, collector in collectors.items():
workbook = os.path.join(directory, name)
msglogger.info("Generating {}".format(workbook))
collector.to_xlsx(workbook)
def main():
global msglogger
check_pytorch_version()
args = parser.parse_args()
# Parse arguments
prsr = parser.getParser()
distiller.knowledge_distillation.add_distillation_args(prsr, ALL_MODEL_NAMES, True)
args = prsr.parse_args()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)
@@ -369,7 +200,7 @@ def main():
msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler))
activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats)
activations_collectors = create_activation_stats_collectors(model, *args.activation_stats)
if args.sensitivity is not None:
sensitivities = np.arange(args.sensitivity_range[0], args.sensitivity_range[1], args.sensitivity_range[2])
@@ -784,7 +615,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsitie
loggers = [loggers]
test_fnc = partial(test, test_loader=data_loader, criterion=criterion,
loggers=loggers, args=args,
activations_collectors=create_activation_stats_collectors(model, None))
activations_collectors=create_activation_stats_collectors(model))
which_params = [param_name for param_name, _ in model.named_parameters()]
sensitivity = distiller.perform_sensitivity_analysis(model,
net_params=which_params,
@@ -823,8 +654,65 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args):
ADC.do_adc(model, args.dataset, args.arch, optimizer_data, validate_fn, save_checkpoint_fn, train_fn)
def create_activation_stats_collectors(model, *phases):
    """Create objects that collect activation statistics.

    This is a utility function that creates four collectors for each requested phase:
    1. Fine-grained sparsity levels of the activations
    2. L1-magnitude of each of the activation channels
    3. Average percentage of zeros (APoZ) of each of the activation channels
    4. Raw activation records for Conv2d layers

    Args:
        model - the model on which we want to collect statistics
        phases - the statistics collection phases: "train", "valid", and/or "test"

    WARNING! Enabling activation statistics collection will significantly slow down training!
    """
class missingdict(dict):
"""This is a little trick to prevent KeyError"""
def __missing__(self, key):
return None # note, does *not* set self[key] - we don't want defaultdict's behavior
distiller.utils.assign_layer_fq_names(model)
genCollectors = lambda: missingdict({
"sparsity": SummaryActivationStatsCollector(model, "sparsity",
lambda t: 100 * distiller.utils.sparsity(t)),
"l1_channels": SummaryActivationStatsCollector(model, "l1_channels",
distiller.utils.activation_channels_l1),
"apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels",
distiller.utils.activation_channels_apoz),
"records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])
})
return {k: (genCollectors() if k in phases else missingdict())
for k in ('train', 'valid', 'test')}
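
For illustration, a hedged usage sketch of the refactored multi-phase API; the toy model is an assumption, and the collector classes are the ones imported from distiller.data_loggers above:

import torch.nn as nn

toy_model = nn.Sequential(nn.Conv2d(3, 8, 3))  # illustrative stand-in model
collectors = create_activation_stats_collectors(toy_model, 'train', 'valid')
assert collectors['train']['sparsity'] is not None  # collector was created
assert collectors['test']['records'] is None  # missingdict returns None
                                              # instead of raising KeyError
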
def save_collectors_data(collectors, directory):
"""Utility function that saves all activation statistics to Excel workbooks
"""
for name, collector in collectors.items():
workbook = os.path.join(directory, name)
msglogger.info("Generating {}".format(workbook))
collector.to_xlsx(workbook)
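
Continuing the sketch above, the collectors for one phase can be dumped after that phase runs (this assumes msglogger has been configured; in this script apputils.config_pylogger attaches the log directory as msglogger.logdir):

# One .xlsx workbook is written per collector ("sparsity", "l1_channels", ...).
save_collectors_data(collectors['train'], msglogger.logdir)
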
def check_pytorch_version():
if torch.__version__ < '0.4.0':
print("\nNOTICE:")
print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to "
"PyTorch API changes which are not backward-compatible.\n"
"Please install PyTorch 0.4.0 or its derivative.\n"
"If you are using a virtual environment, do not forget to update it:\n"
" 1. Deactivate the old environment\n"
" 2. Install the new environment\n"
" 3. Activate the new environment")
exit(1)
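
As an aside, comparing version strings with < is lexicographic; it happens to work for the versions involved here, but a tuple comparison avoids the pitfall (for strings, '0.10.0' < '0.4.0' is True). A minimal sketch, with an illustrative helper name:

import itertools
import torch

def version_tuple(version_str):
    # '0.4.0a0+abc123' -> (0, 4, 0): drop the local part, keep the leading
    # digits of each dot-separated component
    parts = []
    for part in version_str.split('+')[0].split('.'):
        digits = ''.join(itertools.takewhile(str.isdigit, part))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)

if version_tuple(torch.__version__) < (0, 4, 0):
    print('PyTorch 0.4.0 or newer is required')
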
if __name__ == '__main__':
try:
check_pytorch_version()
main()
except KeyboardInterrupt:
print("\n-- KeyboardInterrupt --")
parser.py (new file):
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import distiller
import models
SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
def getParser():
parser = argparse.ArgumentParser(description='Distiller image classification model compression')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', type=lambda s: s.lower(),
choices=models.ALL_MODEL_NAMES,
help='model architecture: ' +
' | '.join(models.ALL_MODEL_NAMES) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(),
# choices=["train", "valid", "test"]
help='collect activation statistics on phases: train, valid, and/or test'
' (WARNING: this slows down training)')
parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
help='print masks sparsity table at end of each epoch')
parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES,
help='print a summary of the model, and exit - options: ' +
' | '.join(SUMMARY_CHOICES))
parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
help='configuration file for pruning the model (default is to use hard-coded schedule)')
parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'], type=lambda s: s.lower(),
help='test the sensitivity of layers to pruning')
parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
help='an optional parameter for sensitivity testing providing the range of sparsities to test.\n'
'This is equivalent to creating sensitivities = np.arange(start, stop, step)')
parser.add_argument('--extras', default=None, type=str,
help='file with extra configuration information')
parser.add_argument('--deterministic', '--det', action='store_true',
help='Ensure deterministic execution for re-producible results.')
parser.add_argument('--gpus', metavar='DEV_ID', default=None,
help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
parser.add_argument('--cpu', action='store_true', default=False,
help='Use CPU only. \n'
'Flag not set => uses GPUs according to the --gpus flag value.'
'Flag set => overrides the --gpus flag')
parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
help='Portion of training dataset to set aside for validation')
parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
help='Display the confusion matrix')
parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
help='List of EarlyExit thresholds (e.g. --earlyexit_thresholds 1.2 0.9)')
parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int,
help='number of best scores to track and report (default: 1)')
parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
help='Load a model without DataParallel wrapping it')
str_to_quant_mode_map = {
'sym': distiller.quantization.LinearQuantMode.SYMMETRIC,
'asym_s': distiller.quantization.LinearQuantMode.ASYMMETRIC_SIGNED,
'asym_u': distiller.quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED,
}
def linear_quant_mode_str(val_str):
    try:
        return str_to_quant_mode_map[val_str]
    except KeyError:
        raise argparse.ArgumentTypeError(
            'Must be one of {0} (received {1})'.format(
                list(str_to_quant_mode_map), val_str))
quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time '
                                        '("post-training quantization")')
quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
                         help='Apply linear quantization to model before evaluation. Applicable only if '
                              '--evaluate is also set')
quant_group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
quant_group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
help='Number of bits for quantization of activations')
quant_group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS',
help='Number of bits for quantization of weights')
quant_group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS',
help='Number of bits for quantization of the accumulator')
quant_group.add_argument('--qe-clip-acts', '--qeca', action='store_true',
help='Enable clipping of activations using min/max values averaging over batch')
quant_group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[],
help='List of layer names for which not to clip activations. Applicable only if '
'--qe-clip-acts is also set')
quant_group.add_argument('--qe-per-channel', '--qepc', action='store_true',
help='Enable per-channel quantization of weights (per output channel)')
return parser
def float_range(val_str):
val = float(val_str)
if val < 0 or val >= 1:
raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
return val
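
A short usage sketch of the new module (run alongside compress_classifier.py so that import parser resolves to this file; the dataset path and option values are illustrative):

import parser

args = parser.getParser().parse_args(
    ['/path/to/dataset',                 # positional DIR argument
     '--arch', 'ResNet18',               # any casing is accepted now
     '--quantize-eval', '--qe-mode', 'asym_u'])
print(args.arch)     # 'resnet18' (lowercased before choices validation)
print(args.qe_mode)  # LinearQuantMode.ASYMMETRIC_UNSIGNED
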
models/__init__.py:
@@ -39,7 +39,7 @@ CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
if name.islower() and not name.startswith("__")
and callable(cifar10_models.__dict__[name]))
ALL_MODEL_NAMES = sorted(set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES))
ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):