diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 98cadfdd2fa276109b9c27ffd6074a63df6bc6de..ab3b5b9dcd7136ae31c7a4f8e76c7d5e021a7986 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -56,7 +56,7 @@ class ClassifierCompressor(object):
     """
     def __init__(self, args, script_dir):
         self.args = copy.deepcopy(args)
-        _infer_implicit_args(self.args)
+        self._infer_implicit_args(self.args)
         self.logdir = _init_logger(self.args, script_dir)
         _config_determinism(self.args)
         _config_compute_device(self.args)
@@ -87,6 +87,25 @@ class ClassifierCompressor(object):
     def data_loaders(self):
         return self.train_loader, self.val_loader, self.test_loader

+    @staticmethod
+    def _infer_implicit_args(args):
+        # Infer the dataset from the model name
+        if not hasattr(args, 'dataset'):
+            args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch)
+        if not hasattr(args, "num_classes"):
+            args.num_classes = distiller.apputils.classification_num_classes(args.dataset)
+        return args
+
+    @staticmethod
+    def mock_args():
+        """Generate a Namespace based on default arguments"""
+        return ClassifierCompressor._infer_implicit_args(
+            init_classifier_compression_arg_parser().parse_args(['fictive_required_arg',]))
+
+    @classmethod
+    def mock_classifier(cls):
+        return cls(cls.mock_args(), '')
+
     def train_one_epoch(self, epoch, verbose=True):
         """Train for one epoch"""
         self.load_datasets()
@@ -359,14 +378,6 @@ def _config_compute_device(args):
             torch.cuda.set_device(args.gpus[0])


-def _infer_implicit_args(args):
-    # Infer the dataset from the model name
-    if not hasattr(args, 'dataset'):
-        args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch)
-    if not hasattr(args, "num_classes"):
-        args.num_classes = distiller.apputils.classification_num_classes(args.dataset)
-
-
 def _init_learner(args):
     # Create the model
     model = create_model(args.pretrained, args.dataset, args.arch,
@@ -621,11 +632,14 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1):
         return _validate(val_loader, model, criterion, loggers, args, epoch)


-def test(test_loader, model, criterion, loggers, activations_collectors, args):
+def test(test_loader, model, criterion, loggers=None, activations_collectors=None, args=None):
     """Model Test"""
     msglogger.info('--- test ---------------------')
+    if args is None:
+        args = ClassifierCompressor.mock_args()
     if activations_collectors is None:
         activations_collectors = create_activation_stats_collectors(model, None)
+
     with collectors_context(activations_collectors["test"]) as collectors:
         top1, top5, lossses = _validate(test_loader, model, criterion, loggers, args)
         distiller.log_activation_statistics(-1, "test", loggers, collector=collectors['sparsity'])
@@ -820,7 +834,7 @@ def earlyexit_validate_stats(args):
     return total_top1, total_top5, losses_exits_stats


-def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args, scheduler=None):
+def evaluate_model(test_loader, model, criterion, loggers, activations_collectors=None, args=None, scheduler=None):
     # This sample application can be invoked to evaluate the accuracy of your model on
     # the test dataset.
     # You can optionally quantize the model to 8-bit integer before evaluation.
@@ -830,30 +844,66 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector
     if not isinstance(loggers, list):
         loggers = [loggers]

-    if args.quantize_eval:
-        model.cpu()
-        quantizer = quantization.PostTrainLinearQuantizer.from_args(model, args)
-        quantizer.prepare_model(distiller.get_dummy_input(input_shape=model.input_shape))
-        model.to(args.device)
+    if not args.quantize_eval:
+        return test(test_loader, model, criterion, loggers, activations_collectors, args=args)
+    else:
+        return quantize_and_test_model(test_loader, model, criterion, args, loggers,
+                                       scheduler=scheduler, save_flag=True)
+

-    top1, _, _ = test(test_loader, model, criterion, loggers, activations_collectors, args=args)
+def quantize_and_test_model(test_loader, model, criterion, args, loggers=None, scheduler=None, save_flag=True):
+    """Collect stats using test_loader (when stats file is absent),

-    if args.quantize_eval:
+    clone the model and quantize the clone, and finally, test it.
+    args.device is allowed to differ from the model's device.
+    When args.qe_calibration is set to None, uses 0.05 instead.
+
+    scheduler - pass scheduler to store it in checkpoint
+    save_flag - defaults to save both quantization statistics and checkpoint.
+    """
+    if not (args.qe_dynamic or args.qe_stats_file or args.qe_config_file):
+        args_copy = copy.deepcopy(args)
+        args_copy.qe_calibration = args.qe_calibration if args.qe_calibration is not None else 0.05
+
+        # set stats into args stats field
+        args.qe_stats_file = acts_quant_stats_collection(
+            model, criterion, loggers, args_copy, save_to_file=save_flag)
+
+    args_qe = copy.deepcopy(args)
+    if args.device == 'cpu':
+        # NOTE: Even though args.device is CPU, we allow here that model is not in CPU.
+        qe_model = distiller.make_non_parallel_copy(model).cpu()
+    else:
+        qe_model = copy.deepcopy(model).to(args.device)
+
+    quantizer = quantization.PostTrainLinearQuantizer.from_args(qe_model, args_qe)
+    quantizer.prepare_model(distiller.get_dummy_input(input_shape=model.input_shape))
+
+    test_res = test(test_loader, qe_model, criterion, loggers, args=args_qe)
+
+    if save_flag:
         checkpoint_name = 'quantized'
-        apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=scheduler,
-                                 name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name,
-                                 dir=msglogger.logdir, extras={'quantized_top1': top1})
+        apputils.save_checkpoint(0, args_qe.arch, qe_model, scheduler=scheduler,
+                                 name='_'.join([args_qe.name, checkpoint_name]) if args_qe.name else checkpoint_name,
+                                 dir=msglogger.logdir, extras={'quantized_top1': test_res[0]})

-def acts_quant_stats_collection(model, criterion, loggers, args):
+    del qe_model
+    return test_res
+
+
+def acts_quant_stats_collection(model, criterion, loggers, args, test_loader=None, save_to_file=False):
     msglogger.info('Collecting quantization calibration stats based on {:.1%} of test dataset'
                    .format(args.qe_calibration))
-    model = distiller.utils.make_non_parallel_copy(model)
-    args.effective_test_size = args.qe_calibration
-    test_loader = load_data(args, fixed_subset=True, load_train=False, load_val=False)
+    if test_loader is None:
+        tmp_args = copy.deepcopy(args)
+        tmp_args.effective_test_size = tmp_args.qe_calibration
+        test_loader = load_data(tmp_args, fixed_subset=True, load_train=False, load_val=False)
     test_fn = partial(test, test_loader=test_loader, criterion=criterion,
                       loggers=loggers, args=args, activations_collectors=None)
-    collect_quant_stats(model, test_fn, save_dir=msglogger.logdir,
-                        classes=None, inplace_runtime_check=True, disable_inplace_attrs=True)
+    with distiller.get_nonparallel_clone_model(model) as cmodel:
+        return collect_quant_stats(cmodel, test_fn, classes=None,
+                                   inplace_runtime_check=True, disable_inplace_attrs=True,
+                                   save_dir=msglogger.logdir if save_to_file else None)


 def acts_histogram_collection(model, criterion, loggers, args):
diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index a7621d9d3fb33d848a8d4a2ec6a2e52fd9284d41..d3f98cb79eaefb0d7f13434342d6a68741155cf1 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -238,7 +238,11 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):

         records_dict = self.value()
         with xlsxwriter.Workbook(fname) as workbook:
-            worksheet = workbook.add_worksheet(self.stat_name)
+            try:
+                worksheet = workbook.add_worksheet(self.stat_name)
+            except xlsxwriter.exceptions.InvalidWorksheetName:
+                worksheet = workbook.add_worksheet()
+
             col_names = []
             for col, (module_name, module_summary_data) in enumerate(records_dict.items()):
                 if not isinstance(module_summary_data, list):
@@ -316,7 +320,11 @@ class RecordsActivationStatsCollector(ActivationStatsCollector):
         records_dict = self.value()
         with xlsxwriter.Workbook(fname) as workbook:
             for module_name, module_act_records in records_dict.items():
-                worksheet = workbook.add_worksheet(module_name)
+                try:
+                    worksheet = workbook.add_worksheet(module_name)
+                except xlsxwriter.exceptions.InvalidWorksheetName:
+                    worksheet = workbook.add_worksheet()
+
                 col_names = []
                 for col, (col_name, col_data) in enumerate(module_act_records.items()):
                     if col_name == 'shape':
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 3360ed564a832d1069fe619afee16d7e45814110..a5e03b4b3f131c18aa210cd0f5418e8e4297e068 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -236,14 +236,9 @@ def add_post_train_quant_args(argparser):
     group = argparser.add_argument_group('Arguments controlling quantization at evaluation time '
                                          '("post-training quantization")')
-    exc_group = group.add_mutually_exclusive_group()
-    exc_group.add_argument('--quantize-eval', '--qe', action='store_true',
+    group.add_argument('--quantize-eval', '--qe', action='store_true',
                            help='Apply linear quantization to model before evaluation. Applicable only if '
                                 '--evaluate is also set')
-    exc_group.add_argument('--qe-calibration', type=distiller.utils.float_range_argparse_checker(exc_min=True),
-                           metavar='PORTION_OF_TEST_SET',
-                           help='Run the model in evaluation mode on the specified portion of the test dataset and '
-                                'collect statistics. Ignores all other \'qe--*\' arguments')
     group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
                        help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
     group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
@@ -264,10 +259,16 @@ def add_post_train_quant_args(argparser):
     group.add_argument('--qe-scale-approx-bits', '--qesab', type=int, metavar='NUM_BITS',
                        help='Enables scale factor approximation using integer multiply + bit shift, using '
                             'this number of bits the integer multiplier')
-    group.add_argument('--qe-stats-file', type=str, metavar='PATH',
-                       help='Path to YAML file with calibration stats. If not given, dynamic quantization will '
-                            'be run (Note that not all layer types are supported for dynamic quantization)')
-    group.add_argument('--qe-config-file', type=str, metavar='PATH',
+
+    stats_group = group.add_mutually_exclusive_group()
+    stats_group.add_argument('--qe-stats-file', type=str, metavar='PATH',
+                             help='Path to YAML file with pre-made calibration stats')
+    stats_group.add_argument('--qe-dynamic', action='store_true', help='Apply dynamic quantization')
+    stats_group.add_argument('--qe-calibration', type=distiller.utils.float_range_argparse_checker(exc_min=True),
+                             metavar='PORTION_OF_TEST_SET', default=None,
+                             help='Run the model in evaluation mode on the specified portion of the test dataset and '
+                                  'collect statistics. Ignores all other \'--qe-*\' arguments')
+    stats_group.add_argument('--qe-config-file', type=str, metavar='PATH',
                        help='Path to YAML file containing configuration for PostTrainLinearQuantizer (if present, '
                             'all other --qe* arguments are ignored)')

@@ -1276,7 +1277,7 @@ class PostTrainLinearQuantizer(Quantizer):
                    mode=args.qe_mode,
                    clip_acts=args.qe_clip_acts,
                    per_channel_wts=args.qe_per_channel,
-                   model_activation_stats=args.qe_stats_file,
+                   model_activation_stats=(None if args.qe_dynamic else args.qe_stats_file),
                    clip_n_stds=args.qe_clip_n_stds,
                    scale_approx_mult_bits=args.qe_scale_approx_bits,
                    overrides=overrides,
diff --git a/distiller/utils.py b/distiller/utils.py
index f50adca41863b18c4f00f2e13966018936c61cd7..63127ff8b818637cf97a7ea7de685217c4dc151f 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -21,6 +21,7 @@ with some random helper functions.
 """
 import argparse
 from collections import OrderedDict
+import contextlib
 from copy import deepcopy
 import logging
 import operator
@@ -516,6 +517,8 @@ def log_activation_statistics(epoch, phase, loggers, collector):
     """Log information about the sparsity of the activations"""
     if collector is None:
         return
+    if loggers is None:
+        return
     for logger in loggers:
         logger.log_activation_statistic(phase, collector.stat_name, collector.value(), epoch)

@@ -644,6 +647,15 @@ def make_non_parallel_copy(model):
     return new_model


+@contextlib.contextmanager
+def get_nonparallel_clone_model(model):
+    clone_model = make_non_parallel_copy(model)
+    try:
+        yield clone_model
+    finally:
+        del clone_model
+
+
 def set_seed(seed):
     """Seed the PRNG for the CPU, Cuda, numpy and Python"""
     torch.manual_seed(seed)
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 09dab168a94848206cc4839e5ae36489f734ee1a..8fe0d4bbc249b13e136033aeb0b30e8af4e6c07a 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -98,7 +98,7 @@ def handle_subapps(model, criterion, optimizer, compression_scheduler, pylogger,
                                         os.path.join(msglogger.logdir, args.export_onnx),
                                         args.dataset, add_softmax=True, verbose=False)
         do_exit = True
-    elif args.qe_calibration:
+    elif args.qe_calibration and not (args.evaluate and args.quantize_eval):
         classifier.acts_quant_stats_collection(model, criterion, pylogger, args)
         do_exit = True
     elif args.activation_histograms:
@@ -111,9 +111,9 @@ def handle_subapps(model, criterion, optimizer, compression_scheduler, pylogger,
         do_exit = True
     elif args.evaluate:
         test_loader = load_test_data(args)
-        activations_collectors = classifier.create_activation_stats_collectors(model, *args.activation_stats)
-        classifier.evaluate_model(model, criterion, test_loader, pylogger, activations_collectors,
-                                  args, compression_scheduler)
+        classifier.evaluate_model(test_loader, model, criterion, pylogger,
+                                  classifier.create_activation_stats_collectors(model, *args.activation_stats),
+                                  args, scheduler=compression_scheduler)
         do_exit = True
     elif args.thinnify:
         assert args.resumed_checkpoint_path is not None, \
diff --git a/requirements.txt b/requirements.txt
index 11bae09c4644fc77c15ea3792fd549a0c21d1a16..b6c115a4c037ddf8a468d57f1aad2f8de8734954 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,7 +16,7 @@ ipywidgets==7.4.2
 bqplot==0.11.5
 pyyaml
 pytest~=4.6.1
-xlsxwriter>=1.1.1
+xlsxwriter>=1.2.2
 pretrainedmodels==0.7.4
 scikit-learn==0.21.2
 gym==0.12.5
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index 6704f203d4237427ba9deaa0e92929ffd39485cf..407fe2074cab943685a283802746e4ccf23c9eba 100755
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -157,7 +157,7 @@ test_configs = [
                (18.160, 65.310), (17.04, 64.42)]),
     TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.460, 91.230]),
-    TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'.
+    TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-dynamic --qe-clip-acts avg --qe-no-clip-layers {1}'.
               format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'),
                      'fc'),
               DS_CIFAR, accuracy_checker, [91.57, 99.62]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
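
Usage sketch (not part of the patch): the image_classifier.py hunks above make post-training quantization callable as an API -- evaluate_model() now delegates to the new quantize_and_test_model(), which collects calibration stats implicitly when none of qe_dynamic / qe_stats_file / qe_config_file is set. The snippet below is an illustrative, untested sketch of driving that flow programmatically; the dataset path and the resnet20_cifar choice are placeholder assumptions, not part of the change.

# Illustrative sketch only -- not part of the patch. Placeholder assumptions:
# a CIFAR-10 image folder at args.data, and the resnet20_cifar architecture.
import torch.nn as nn

import distiller.apputils as apputils
from distiller.apputils.image_classifier import (ClassifierCompressor, load_data,
                                                 quantize_and_test_model)
from distiller.models import create_model

args = ClassifierCompressor.mock_args()   # default args, including the new qe_* flags
args.data = '/path/to/cifar10'            # placeholder dataset location
args.arch = 'resnet20_cifar'
args.dataset = apputils.classification_dataset_str_from_arch(args.arch)
args.num_classes = apputils.classification_num_classes(args.dataset)
args.device = 'cpu'
args.quantize_eval = True                 # args.qe_calibration stays None -> 0.05 fallback

model = create_model(False, args.dataset, args.arch)
test_loader = load_data(args, load_train=False, load_val=False)

# None of qe_dynamic / qe_stats_file / qe_config_file is set, so this first
# collects calibration stats on the fallback 5% of the test set, then quantizes
# a clone of the model and tests it. save_flag=False skips writing the stats
# YAML and the 'quantized' checkpoint.
top1, top5, loss = quantize_and_test_model(test_loader, model, nn.CrossEntropyLoss(),
                                           args, save_flag=False)

On the command line, the same flow is reached via --quantize-eval --evaluate, with --qe-calibration, --qe-stats-file, --qe-dynamic and --qe-config-file now forming a mutually exclusive group that selects where the quantization ranges come from.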