diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py index 91005ebc7bb1407d2b73840ec611ee6ca6ebacf8..a3c4ae62463c10ec95dc6b1e7fdfb73503c225b9 100755 --- a/distiller/apputils/checkpoint.py +++ b/distiller/apputils/checkpoint.py @@ -60,7 +60,8 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None, if best_top1 is not None: checkpoint['best_top1'] = best_top1 if optimizer is not None: - checkpoint['optimizer'] = optimizer.state_dict() + checkpoint['optimizer_state_dict'] = optimizer.state_dict() + checkpoint['optimizer_type'] = type(optimizer) if scheduler is not None: checkpoint['compression_sched'] = scheduler.state_dict() if hasattr(model, 'thinning_recipes'): @@ -73,13 +74,22 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None, shutil.copyfile(fullpath, fullpath_best) -def load_checkpoint(model, chkpt_file, optimizer=None): - """Load a pytorch training checkpoint +def load_lean_checkpoint(model, chkpt_file, model_device=None): + return load_checkpoint(model, chkpt_file, model_device=model_device, + lean_checkpoint=True)[0] + + +def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lean_checkpoint=False): + """Load a pytorch training checkpoint. Args: model: the pytorch model to which we will load the parameters chkpt_file: the checkpoint file - optimizer: the optimizer to which we will load the serialized state + lean_checkpoint: if set, read into model only 'state_dict' field + optimizer: [deprecated argument] + model_device [str]: if set, call model.to($model_device) + This should be set to either 'cpu' or 'cuda'. + :returns: updated model, compression_scheduler, optimizer, start_epoch """ if not os.path.isfile(chkpt_file): raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file) @@ -132,9 +142,43 @@ def load_checkpoint(model, chkpt_file, optimizer=None): quantizer = qmd['type'](model, **qmd['params']) quantizer.prepare_model() - msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file), - e=checkpoint_epoch)) if normalize_dataparallel_keys: checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()} model.load_state_dict(checkpoint['state_dict']) - return (model, compression_scheduler, start_epoch) + if model_device is not None: + model.to(model_device) + + if lean_checkpoint: + msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(str(chkpt_file))) + return (model, None, None, 0) + + def _load_optimizer(cls, src_state_dict, model): + """Initiate optimizer with model parameters and load src_state_dict""" + # initiate the dest_optimizer with a dummy learning rate, + # this is required to support SGD.__init__() + dest_optimizer = cls(model.parameters(), lr=1) + dest_optimizer.load_state_dict(src_state_dict) + return dest_optimizer + + try: + optimizer = _load_optimizer(checkpoint['optimizer_type'], + checkpoint['optimizer_state_dict'], model) + except KeyError: + if 'optimizer' not in checkpoint: + raise + # older checkpoints didn't support this feature + # they had the 'optimizer' field instead + optimizer = None + + if optimizer is not None: + msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format( + type=type(optimizer))) + msglogger.info('Optimizer Args: {}'.format( + dict((k,v) for k,v in optimizer.state_dict()['param_groups'][0].items() + if k != 'params'))) + else: + msglogger.warning('Optimizer could not be loaded from checkpoint.') + + msglogger.info("=> loaded checkpoint 
'{f}' (epoch {e})".format(f=str(chkpt_file), + e=checkpoint_epoch)) + return (model, compression_scheduler, optimizer, start_epoch) diff --git a/distiller/config.py b/distiller/config.py index c0fb982efc0b9469a187f3f665b3380b0db7a149..9367e88e70abab1be7cbaf21e43dfb0f44a3a88b 100755 --- a/distiller/config.py +++ b/distiller/config.py @@ -49,7 +49,7 @@ msglogger = logging.getLogger() app_cfg_logger = logging.getLogger("app_cfg") -def dict_config(model, optimizer, sched_dict, scheduler=None): +def dict_config(model, optimizer, sched_dict, scheduler=None, resumed_epoch=None): app_cfg_logger.debug('Schedule contents:\n' + json.dumps(sched_dict, indent=2)) if scheduler is None: @@ -110,7 +110,8 @@ def dict_config(model, optimizer, sched_dict, scheduler=None): add_policy_to_scheduler(policy, policy_def, scheduler) # Any changes to the optmizer caused by a quantizer have occured by now, so safe to create LR schedulers - lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer) + lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer, + last_epoch=(resumed_epoch if resumed_epoch is not None else -1)) for policy_def in lr_policies: instance_name, args = __policy_params(policy_def, 'lr_scheduler') assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format( @@ -138,13 +139,13 @@ def add_policy_to_scheduler(policy, policy_def, scheduler): frequency=policy_def['frequency']) -def file_config(model, optimizer, filename, scheduler=None): +def file_config(model, optimizer, filename, scheduler=None, resumed_epoch=None): """Read the schedule from file""" with open(filename, 'r') as stream: msglogger.info('Reading compression schedule from: %s', filename) try: sched_dict = distiller.utils.yaml_ordered_load(stream) - return dict_config(model, optimizer, sched_dict, scheduler) + return dict_config(model, optimizer, sched_dict, scheduler, resumed_epoch) except yaml.YAMLError as exc: print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename) raise diff --git a/distiller/policy.py b/distiller/policy.py index f42219f94322fb27177e126e796582832f0a8538..db45e6f18a4efd0f9a04fb88b5c53cd0b6214d52 100755 --- a/distiller/policy.py +++ b/distiller/policy.py @@ -21,6 +21,7 @@ - LRPolicy: learning-rate decay scheduling """ import torch +import torch.optim.lr_scheduler from collections import namedtuple import logging @@ -42,7 +43,7 @@ class ScheduledTrainingPolicy(object): self.classes = classes self.layers = layers - def on_epoch_begin(self, model, zeros_mask_dict, meta): + def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs): """A new epcoh is about to begin""" pass @@ -115,7 +116,7 @@ class PruningPolicy(ScheduledTrainingPolicy): self.mini_batch_id = 0 # The ID of the mini_batch within the present epoch self.global_mini_batch_id = 0 # The ID of the mini_batch within the present training session - def on_epoch_begin(self, model, zeros_mask_dict, meta): + def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs): msglogger.debug("Pruner {} is about to prune".format(self.pruner.name)) self.mini_batch_id = 0 self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1) @@ -169,7 +170,7 @@ class RegularizationPolicy(ScheduledTrainingPolicy): self.keep_mask = keep_mask self.is_last_epoch = False - def on_epoch_begin(self, model, zeros_mask_dict, meta): + def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs): self.is_last_epoch = meta['current_epoch'] == 
(meta['ending_epoch'] - 1) def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, @@ -202,15 +203,19 @@ class RegularizationPolicy(ScheduledTrainingPolicy): class LRPolicy(ScheduledTrainingPolicy): - """ Learning-rate decay scheduling policy. + """Learning-rate decay scheduling policy. """ def __init__(self, lr_scheduler): super(LRPolicy, self).__init__() self.lr_scheduler = lr_scheduler - def on_epoch_begin(self, model, zeros_mask_dict, meta): - self.lr_scheduler.step() + def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs): + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + # Note: ReduceLROnPlateau doesn't inherit from _LRScheduler + self.lr_scheduler.step(kwargs['metrics'], epoch=meta['current_epoch']) + else: + self.lr_scheduler.step(epoch=meta['current_epoch']) class QuantizationPolicy(ScheduledTrainingPolicy): diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 2348eea106dc13636e50b7841d40d4f69b4a94d3..7104ac19815093ac6fd0a4bd630391b8ae36590b 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -106,12 +106,12 @@ class CompressionScheduler(object): 'ending_epoch': ending_epoch, 'frequency': frequency} - def on_epoch_begin(self, epoch, optimizer=None): - if epoch in self.policies: - for policy in self.policies[epoch]: - meta = self.sched_metadata[policy] - meta['current_epoch'] = epoch - policy.on_epoch_begin(self.model, self.zeros_mask_dict, meta) + def on_epoch_begin(self, epoch, optimizer=None, **kwargs): + for policy in self.policies.get(epoch, list()): + meta = self.sched_metadata[policy] + meta['current_epoch'] = epoch + policy.on_epoch_begin(self.model, self.zeros_mask_dict, meta, + **kwargs) def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None): if epoch in self.policies: diff --git a/distiller/thinning.py b/distiller/thinning.py index 94758369c3b051e3c0ae5e0a5fea3e0c29e6c20e..78cd2879469b1c62f40094f2702f4e4e2e6ecedb 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -463,8 +463,12 @@ def optimizer_thinning(optimizer, param, dim, indices, new_shape=None): This function is brittle as it is tested on SGD only and relies on the internal representation of the SGD optimizer, which can change w/o notice. """ - if optimizer is None or not isinstance(optimizer, torch.optim.SGD): + if optimizer is None: return False + + if not isinstance(optimizer, torch.optim.SGD): + raise NotImplementedError('optimizer thinning supports only SGD') + for group in optimizer.param_groups: momentum = group.get('momentum', 0) if momentum == 0: @@ -473,7 +477,7 @@ def optimizer_thinning(optimizer, param, dim, indices, new_shape=None): if id(p) != id(param): continue param_state = optimizer.state[p] - if 'momentum_buffer' in param_state: + if param_state.get('momentum_buffer', None) is not None: param_state['momentum_buffer'] = torch.index_select(param_state['momentum_buffer'], dim, indices) if new_shape is not None: msglogger.debug("optimizer_thinning: new shape {}".format(*new_shape)) @@ -526,9 +530,6 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr grad_selection_view = param.grad.resize_(*directive[2]) if grad_selection_view.size(dim) != len_indices: param.grad = torch.index_select(grad_selection_view, dim, indices) - if optimizer_thinning(optimizer, param, dim, indices, directive[3]): - msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})". 
- format(param_name, dim, len_indices, directive[3])) param.data = param.view(*directive[3]) if param.grad is not None: @@ -542,8 +543,13 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr # not exist, and therefore won't need to be re-dimensioned. if param.grad is not None and param.grad.size(dim) != len_indices: param.grad = torch.index_select(param.grad, dim, indices.to(param.device)) - if optimizer_thinning(optimizer, param, dim, indices): - msglogger.debug("Updated velocity buffer %s" % param_name) + + # update optimizer + if optimizer_thinning(optimizer, param, dim, indices, + new_shape=directive[3] if len(directive)==4 else None): + msglogger.debug("Updated velocity buffer %s" % param_name) + else: + msglogger.debug('Failed to update the optimizer by thinning directive') if not loaded_from_file: # If the masks are loaded from a checkpoint file, then we don't need to change diff --git a/distiller/utils.py b/distiller/utils.py index 875b4043e1c8a3cb86037edd9965500db3a74605..ac4a31abe323ef85948c9fce20e2c02c9813fe1c 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -39,6 +39,10 @@ def model_device(model): return 'cpu' +def optimizer_device_name(opt): + return str(list(list(opt.state)[0])[0].device) + + def to_np(var): return var.data.cpu().numpy() diff --git a/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml b/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml index 6e587355ad79523db6bcdbba28e2392c5dfeabd6..dd51be800c79caab762907c503575b18b35b1ccb 100755 --- a/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml +++ b/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml @@ -1,4 +1,5 @@ -# time python3 compress_classifier.py -a=mobilenet -p=50 --lr=0.001 ../../../data.imagenet/ -j=22 --resume=mobilenet_sgd_68.848.pth.tar --epochs=96 --compress=../agp-pruning//mobilenet.imagenet.schedule_agp.yaml +# time python3 compress_classifier.py -a=mobilenet -p=50 --lr=0.001 ../../../data.imagenet/ -j=22 --resume-from=mobilenet_sgd_68.848.pth.tar --epochs=96 --compress=../agp-pruning//mobilenet.imagenet.schedule_agp.yaml --reset-optimizer --vs=0 +# # # A pretrained MobileNet (width=1) can be downloaded from: https://github.com/marvis/pytorch-mobilenet (top1: 68.848; top5: 88.740) # @@ -86,18 +87,18 @@ lr_schedulers: policies: - pruner: instance_name : 'conv50_pruner' - starting_epoch: 103 - ending_epoch: 123 + starting_epoch: 0 + ending_epoch: 20 frequency: 2 - pruner: instance_name : 'conv60_pruner' - starting_epoch: 103 - ending_epoch: 123 + starting_epoch: 0 + ending_epoch: 20 frequency: 2 - lr_scheduler: instance_name: pruning_lr - starting_epoch: 103 - ending_epoch: 200 + starting_epoch: 0 + ending_epoch: 100 frequency: 1 diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml index aa04ed5253a80803336d30dc3d1deba68b4a3810..362e52c00c800528060502c14002e65cddf3f07c 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml @@ -14,7 +14,7 @@ # Total sparsity: 41.10 # # of parameters: 120,000 (=55.7% of the baseline parameters) # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-split=0 +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 
--lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ @@ -109,29 +109,29 @@ extensions: policies: - pruner: instance_name : low_pruner - starting_epoch: 180 - ending_epoch: 210 + starting_epoch: 0 + ending_epoch: 30 frequency: 2 - pruner: instance_name : fine_pruner - starting_epoch: 210 - ending_epoch: 230 + starting_epoch: 30 + ending_epoch: 50 frequency: 2 - pruner: instance_name : fc_pruner - starting_epoch: 210 - ending_epoch: 230 + starting_epoch: 30 + ending_epoch: 50 frequency: 2 # After completeing the pruning, we perform network thinning and continue fine-tuning. - extension: instance_name: net_thinner - epochs: [212] + epochs: [32] - lr_scheduler: instance_name: pruning_lr - starting_epoch: 180 + starting_epoch: 0 ending_epoch: 400 frequency: 1 diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml index 4c9d3b5005a40a7cf74a1f7d69323e221f80daf8..1ea06610ad142adb4b641f1f4ab0c0f03dc7e7af 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml @@ -14,8 +14,7 @@ # Total sparsity: 41.84 # # of parameters: 143,488 (=53% of the baseline parameters) # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar -# +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ @@ -103,20 +102,20 @@ extensions: policies: - pruner: instance_name : low_pruner - starting_epoch: 180 - ending_epoch: 200 + starting_epoch: 0 + ending_epoch: 20 frequency: 2 - pruner: instance_name : fine_pruner - starting_epoch: 200 - ending_epoch: 220 + starting_epoch: 20 + ending_epoch: 40 frequency: 2 # After completeing the pruning, we perform network thinning and continue fine-tuning. 
- extension: instance_name: net_thinner - epochs: [202] + epochs: [22] - lr_scheduler: instance_name: pruning_lr diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml index 38798587c8b421a69f7a4a42a6d01116f7d247dd..f958ec13ce230b5764e1bcec22eeef49d1671752 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml @@ -14,7 +14,7 @@ # Total sparsity: 56.41% # # of parameters: 95922 (=35.4% of the baseline parameters ==> 64.6% sparsity) # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-split=0 +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0 # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ @@ -116,26 +116,26 @@ extensions: policies: - pruner: instance_name : low_pruner - starting_epoch: 180 - ending_epoch: 210 + starting_epoch: 0 + ending_epoch: 30 frequency: 2 - pruner: instance_name : fine_pruner1 - starting_epoch: 210 - ending_epoch: 230 + starting_epoch: 30 + ending_epoch: 50 frequency: 2 - pruner: instance_name : fine_pruner2 - starting_epoch: 210 - ending_epoch: 230 + starting_epoch: 30 + ending_epoch: 50 frequency: 2 - pruner: instance_name : fc_pruner - starting_epoch: 210 - ending_epoch: 230 + starting_epoch: 30 + ending_epoch: 50 frequency: 2 # Currently the thinner is disabled until the the structure pruner is done, because it interacts @@ -150,10 +150,10 @@ policies: # After completeing the pruning, we perform network thinning and continue fine-tuning. 
- extension: instance_name: net_thinner - epochs: [212] + epochs: [32] - lr_scheduler: instance_name: pruning_lr - starting_epoch: 180 + starting_epoch: 0 ending_epoch: 400 frequency: 1 diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml index 3f85c2db2c780446a71887a976ab145da41baf96..147d7b8e0cd86b3afa6af9545d397f645b6f2736 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml @@ -14,7 +14,7 @@ # Total sparsity: 39.66 # # of parameters: 78,776 (=29.1% of the baseline parameters) # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-split=0 +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0 # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ @@ -110,26 +110,26 @@ extensions: policies: - pruner: instance_name : low_pruner - starting_epoch: 180 - ending_epoch: 210 + starting_epoch: 0 + ending_epoch: 30 frequency: 2 - pruner: instance_name : low_pruner_2 - starting_epoch: 180 - ending_epoch: 210 + starting_epoch: 0 + ending_epoch: 30 frequency: 2 - pruner: instance_name : fine_pruner - starting_epoch: 210 - ending_epoch: 230 + starting_epoch: 30 + ending_epoch: 50 frequency: 2 - pruner: instance_name : fc_pruner - starting_epoch: 210 - ending_epoch: 230 + starting_epoch: 30 + ending_epoch: 50 frequency: 2 # Currently the thinner is disabled until the the structure pruner is done, because it interacts @@ -144,11 +144,10 @@ policies: # After completeing the pruning, we perform network thinning and continue fine-tuning. - extension: instance_name: net_thinner - #epochs: [181] - epochs: [212] + epochs: [32] - lr_scheduler: instance_name: pruning_lr - starting_epoch: 180 + starting_epoch: 0 ending_epoch: 400 frequency: 1 diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index c33cc3fa60d469d06dd51545663d3d6fccc35f3b..e7b318c937969d98d4ee09e661161d5d4320bdea 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -87,6 +87,8 @@ def main(): # Parse arguments args = parser.get_parser().parse_args() + if args.epochs is None: + args.epochs = 90 if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) @@ -98,6 +100,7 @@ def main(): msglogger.debug("Distiller: %s", distiller.__version__) start_epoch = 0 + ending_epoch = args.epochs perf_scores_history = [] if args.deterministic: # Experiment reproducibility is sometimes important. 
Pete Warden expounded about this @@ -155,19 +158,36 @@ def main(): if args.earlyexit_thresholds: msglogger.info('=> using early-exit threshold values of %s', args.earlyexit_thresholds) - # We can optionally resume from a checkpoint - if args.resume: - model, compression_scheduler, start_epoch = apputils.load_checkpoint(model, chkpt_file=args.resume) - model.to(args.device) + # TODO(barrh): args.deprecated_resume is deprecated since v0.3.1 + if args.deprecated_resume: + msglogger.warning('The "--resume" flag is deprecated. Please use "--resume-from=YOUR_PATH" instead.') + if not args.reset_optimizer: + msglogger.warning('If you wish to also reset the optimizer, call with: --reset-optimizer') + args.reset_optimizer = True + args.resumed_checkpoint_path = args.deprecated_resume - # Define loss function (criterion) and optimizer + # We can optionally resume from a checkpoint + optimizer = None + if args.resumed_checkpoint_path: + model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint( + model, args.resumed_checkpoint_path, model_device=args.device) + elif args.load_model_path: + model = apputils.load_lean_checkpoint(model, args.load_model_path, + model_device=args.device) + if args.reset_optimizer: + start_epoch = 0 + if optimizer is not None: + optimizer = None + msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0') + + # Define loss function (criterion) criterion = nn.CrossEntropyLoss().to(args.device) - optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, - momentum=args.momentum, - weight_decay=args.weight_decay) - msglogger.info('Optimizer Type: %s', type(optimizer)) - msglogger.info('Optimizer Args: %s', optimizer.defaults) + if optimizer is None: + optimizer = torch.optim.SGD(model.parameters(), + lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + msglogger.info('Optimizer Type: %s', type(optimizer)) + msglogger.info('Optimizer Args: %s', optimizer.defaults) if args.AMC: return automated_deep_compression(model, criterion, optimizer, pylogger, args) @@ -211,7 +231,8 @@ def main(): if args.compress: # The main use-case for this sample application is CNN compression. Compression # requires a compression schedule configuration file in YAML. - compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler) + compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler, + (start_epoch-1) if args.resumed_checkpoint_path else None) # Model is re-transferred to GPU in case parameters were added (e.g. 
PACTQuantizer) model.to(args.device) elif compression_scheduler is None: @@ -219,10 +240,10 @@ def main(): if args.thinnify: #zeros_mask_dict = distiller.create_model_masks_dict(model) - assert args.resume is not None, "You must use --resume to provide a checkpoint file to thinnify" + assert args.resumed_checkpoint_path is not None, "You must use --resume-from to provide a checkpoint file to thinnify" distiller.remove_filters(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None) apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=compression_scheduler, - name="{}_thinned".format(args.resume.replace(".pth.tar", "")), dir=msglogger.logdir) + name="{}_thinned".format(args.resumed_checkpoint_path.replace(".pth.tar", "")), dir=msglogger.logdir) print("Note: your model may have collapsed to random inference, so you may want to fine-tune") return @@ -230,7 +251,7 @@ def main(): if args.kd_teacher: teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus) if args.kd_resume: - teacher, _, _ = apputils.load_checkpoint(teacher, chkpt_file=args.kd_resume) + teacher = apputils.load_lean_checkpoint(teacher, args.kd_resume) dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt) args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw) compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch, ending_epoch=args.epochs, @@ -243,11 +264,17 @@ def main(): ' | '.join(['{:.2f}'.format(val) for val in dlw])) msglogger.info('\tStarting from Epoch: %s', args.kd_start_epoch) - for epoch in range(start_epoch, start_epoch + args.epochs): + if start_epoch >= ending_epoch: + msglogger.error( + 'epoch count is too low, starting epoch is {} but total epochs set to {}'.format( + start_epoch, ending_epoch)) + raise ValueError('Epochs parameter is too low. Nothing to do.') + for epoch in range(start_epoch, ending_epoch): # This is the main training loop. msglogger.info('\n') if compression_scheduler: - compression_scheduler.on_epoch_begin(epoch) + compression_scheduler.on_epoch_begin(epoch, + metrics=(vloss if (epoch != start_epoch) else 10**6)) # Train for one epoch with collectors_context(activations_collectors["train"]) as collectors: @@ -595,7 +622,7 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector # the test dataset. # You can optionally quantize the model to 8-bit integer before evaluation. 
# For example: - # python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --resume=checkpoint.pth.tar --evaluate + # python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --resume-from=checkpoint.pth.tar --evaluate if not isinstance(loggers, list): loggers = [loggers] diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py index 9b51f513031c3ae40654db503f0704f158ee23d4..2783b29c311edad9a7f5da1132d7c0978055b2ac 100755 --- a/examples/classifier_compression/parser.py +++ b/examples/classifier_compression/parser.py @@ -36,24 +36,40 @@ def get_parser(): ' (default: resnet18)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') - parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') + parser.add_argument('--epochs', type=int, metavar='N', + help='number of total epochs to run (default: 90') parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') - parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, - metavar='LR', help='initial learning rate') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') + + optimizer_args = parser.add_argument_group('Optimizer arguments') + optimizer_args.add_argument('--lr', '--learning-rate', default=0.1, + type=float, metavar='LR', help='initial learning rate') + optimizer_args.add_argument('--momentum', default=0.9, type=float, + metavar='M', help='momentum') + optimizer_args.add_argument('--weight-decay', '--wd', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)') + parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)') - parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') + + load_checkpoint_group = parser.add_argument_group('Resuming arguments') + load_checkpoint_group_exc = load_checkpoint_group.add_mutually_exclusive_group() + # TODO(barrh): args.deprecated_resume is deprecated since v0.3.1 + load_checkpoint_group_exc.add_argument('--resume', dest='deprecated_resume', default='', type=str, + metavar='PATH', help=argparse.SUPPRESS) + load_checkpoint_group_exc.add_argument('--resume-from', dest='resumed_checkpoint_path', default='', + type=str, metavar='PATH', + help='path to latest checkpoint. Use to resume paused training session.') + load_checkpoint_group_exc.add_argument('--exp-load-weights-from', dest='load_model_path', + default='', type=str, metavar='PATH', + help='path to checkpoint to load weights from (excluding other fields) (experimental)') + load_checkpoint_group.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') + load_checkpoint_group.add_argument('--reset-optimizer', action='store_true', + help='Flag to override optimizer if resumed from checkpoint. 
This will reset epochs count.') + parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') - parser.add_argument('--pretrained', dest='pretrained', action='store_true', - help='use pre-trained model') parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(), help='collect activation statistics on phases: train, valid, and/or test' ' (WARNING: this slows down training)') diff --git a/examples/network_surgery/resnet20.network_surgery.yaml b/examples/network_surgery/resnet20.network_surgery.yaml index 06644da5de7cd594fc29713b5890f0fd7203e4b7..e77b980142ac5dbc298163fdacbce67a321cae4e 100755 --- a/examples/network_surgery/resnet20.network_surgery.yaml +++ b/examples/network_surgery/resnet20.network_surgery.yaml @@ -31,7 +31,7 @@ # Total sparsity: 69.1% # # of parameters: 83,671 # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 --compress=../network_surgery/resnet20.network_surgery.yaml -j=1 --deterministic --validation-split=0 --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --masks-sparsity --num-best-scores=10 +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 --compress=../network_surgery/resnet20.network_surgery.yaml --vs=0 --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --masks-sparsity --num-best-scores=5 # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ @@ -127,13 +127,13 @@ policies: mini_batch_pruning_frequency: 1 mask_on_forward_only: True # use_double_copies: True - starting_epoch: 180 - ending_epoch: 280 + starting_epoch: 0 + ending_epoch: 100 frequency: 1 - lr_scheduler: instance_name: training_lr - starting_epoch: 225 + starting_epoch: 0 ending_epoch: 400 frequency: 1 diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml index cf6c7220d39d85ce95e1f933174ff774beeebbd7..c12f9f01f5a456ba80e113711e76f6a04bb9aea2 100755 --- a/examples/network_trimming/resnet56_cifar_activation_apoz.yaml +++ b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml @@ -14,7 +14,7 @@ # Total MACs: 78,856,832 # # -# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic --act-stats=valid +# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --act-stats=valid # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ @@ -167,35 +167,34 @@ lr_schedulers: policies: - pruner: instance_name: filter_pruner_60 - starting_epoch: 181 - ending_epoch: 200 + starting_epoch: 1 + ending_epoch: 20 frequency: 2 - pruner: instance_name: filter_pruner_50 - starting_epoch: 181 - ending_epoch: 200 + starting_epoch: 1 + ending_epoch: 20 frequency: 2 - pruner: 
instance_name: filter_pruner_30 - starting_epoch: 181 - ending_epoch: 200 + starting_epoch: 1 + ending_epoch: 20 frequency: 2 - pruner: instance_name: filter_pruner_10 - starting_epoch: 181 - ending_epoch: 200 + starting_epoch: 1 + ending_epoch: 20 frequency: 2 - extension: instance_name: net_thinner - epochs: [200] - #epochs: [181] + epochs: [20] - lr_scheduler: instance_name: exp_finetuning_lr - starting_epoch: 190 + starting_epoch: 10 ending_epoch: 300 frequency: 1 diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml index 4f65e441bb6afa19cf1b8de8ece040fed2272c0e..bf6d078200696cece0291a7088b488c997e04dd5 100755 --- a/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml +++ b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml @@ -14,7 +14,7 @@ # Total MACs: 67,797,632 # # -# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz_v2.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic --act-stats=valid +# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz_v2.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --act-stats=valid # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ @@ -168,34 +168,34 @@ lr_schedulers: policies: - pruner: instance_name: filter_pruner_60 - starting_epoch: 181 - ending_epoch: 200 + starting_epoch: 1 + ending_epoch: 20 frequency: 2 - pruner: instance_name: filter_pruner_50 - starting_epoch: 181 - ending_epoch: 200 + starting_epoch: 1 + ending_epoch: 20 frequency: 2 - pruner: instance_name: filter_pruner_30 - starting_epoch: 181 - ending_epoch: 200 + starting_epoch: 1 + ending_epoch: 20 frequency: 2 - pruner: instance_name: filter_pruner_10 - starting_epoch: 181 - ending_epoch: 200 + starting_epoch: 1 + ending_epoch: 20 frequency: 2 - extension: instance_name: net_thinner - epochs: [200] + epochs: [20] - lr_scheduler: instance_name: exp_finetuning_lr - starting_epoch: 190 + starting_epoch: 10 ending_epoch: 300 frequency: 1 diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml index 4d934c5bd8867809d2aebaf240d6cb9853e10cbc..68c760456eaee8ff59ad721ca261b0fe3e231215 100755 --- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml +++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml @@ -5,7 +5,7 @@ # However, instead of one-shot filter ranking and pruning, we perform one-shot channel ranking and # pruning, using L1-magnitude of the structures. 
# -# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic +# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0 # # Baseline results: # Top1: 92.850 Top5: 99.780 Loss: 0.464 @@ -164,26 +164,26 @@ lr_schedulers: policies: - pruner: instance_name: filter_pruner_70 - epochs: [180] + epochs: [0] - pruner: instance_name: filter_pruner_60 - epochs: [180] + epochs: [0] - pruner: instance_name: filter_pruner_40 - epochs: [180] + epochs: [0] - pruner: instance_name: filter_pruner_20 - epochs: [180] + epochs: [0] - extension: instance_name: net_thinner - epochs: [180] + epochs: [0] - lr_scheduler: instance_name: exp_finetuning_lr - starting_epoch: 190 + starting_epoch: 10 ending_epoch: 300 frequency: 1 diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml index d4dce9622f98f6d4d664f0a241c336f52c4a1a80..e52adf6700b90cba9fb89617fcb751f9dcb9a054 100755 --- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml +++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml @@ -14,7 +14,7 @@ # You may either train this model from scratch, or download it from the link below. # https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar # -# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic +# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0 # # Results: 62.7% of the original convolution MACs (when calculated using direct convolution) # @@ -175,26 +175,26 @@ lr_schedulers: policies: - pruner: instance_name: filter_pruner_60 - epochs: [180] + epochs: [0] - pruner: instance_name: filter_pruner_50 - epochs: [180] + epochs: [0] - pruner: instance_name: filter_pruner_30 - epochs: [180] + epochs: [0] - pruner: instance_name: filter_pruner_10 - epochs: [180] + epochs: [0] - extension: instance_name: net_thinner - epochs: [180] + epochs: [0] - lr_scheduler: instance_name: exp_finetuning_lr - starting_epoch: 190 + starting_epoch: 10 ending_epoch: 300 frequency: 1 diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml index 6e00522e43dd568918e08625ede4baba212929b8..58ac2e87b66f9963df2870225789d0cf156dcdf8 100755 --- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml +++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml @@ -15,7 +15,7 @@ # You may either train this model from scratch, or download it from the link below. 
# https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar # -# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic +# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0 # # Results: 53.9% (1.85x) of the original convolution MACs (when calculated using direct convolution) # @@ -177,26 +177,26 @@ lr_schedulers: policies: - pruner: instance_name: filter_pruner_70 - epochs: [180] + epochs: [0] - pruner: instance_name: filter_pruner_60 - epochs: [180] + epochs: [0] - pruner: instance_name: filter_pruner_40 - epochs: [180] + epochs: [0] - pruner: instance_name: filter_pruner_20 - epochs: [180] + epochs: [0] - extension: instance_name: net_thinner - epochs: [180] + epochs: [0] - lr_scheduler: instance_name: exp_finetuning_lr - starting_epoch: 190 + starting_epoch: 10 ending_epoch: 300 frequency: 1 diff --git a/examples/ssl/ssl_channels-removal_finetuning.yaml b/examples/ssl/ssl_channels-removal_finetuning.yaml index 549629cf20cd4a102b1610fd29c6ea46b5e6fa59..89f65b88b168cc3a11913ac236fb04aa96e0a3e7 100755 --- a/examples/ssl/ssl_channels-removal_finetuning.yaml +++ b/examples/ssl/ssl_channels-removal_finetuning.yaml @@ -3,7 +3,7 @@ # # We save the output (i.e. checkpoint.pth.tar) in ../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20_finetuned.pth.tar # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20.pth.tar +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20.pth.tar --reset-optimizer # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ @@ -58,6 +58,6 @@ lr_schedulers: policies: - lr_scheduler: instance_name: training_lr - starting_epoch: 45 + starting_epoch: 0 ending_epoch: 300 frequency: 1 diff --git a/examples/ssl/ssl_channels-removal_finetuning_x1.8.yaml b/examples/ssl/ssl_channels-removal_finetuning_x1.8.yaml index 195b09865815e871343407f6e14c74ff4eb8729b..4e96499b9000556c64c201d8568985d8746f8062 100755 --- a/examples/ssl/ssl_channels-removal_finetuning_x1.8.yaml +++ b/examples/ssl/ssl_channels-removal_finetuning_x1.8.yaml @@ -7,7 +7,7 @@ # # Total MACs: 22,583,936 == 55.3% compute density # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=<enter-your-model> +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning_x1.8.yaml --reset-optimizer 
--resume-from=../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20.pth.tar # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ @@ -66,6 +66,6 @@ lr_schedulers: policies: - lr_scheduler: instance_name: training_lr - starting_epoch: 45 + starting_epoch: 0 ending_epoch: 300 frequency: 1 diff --git a/examples/ssl/ssl_filter-removal_training.yaml b/examples/ssl/ssl_filter-removal_training.yaml index 84c30e861bca8c8f147c5d124947e6804471dec7..7dade219c405e4e2468d3d8c12f58e015b28d28e 100755 --- a/examples/ssl/ssl_filter-removal_training.yaml +++ b/examples/ssl/ssl_filter-removal_training.yaml @@ -9,7 +9,7 @@ # time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_filter-removal_training.yaml -j=1 --deterministic --name="filters" # # To fine-tune: -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=... +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml --reset-optimizer --resume-from=... # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ diff --git a/examples/ssl/vgg16_cifar_ssl_channels_training.yaml b/examples/ssl/vgg16_cifar_ssl_channels_training.yaml index 9b7155bfab9b1355b207b66230c7744ca03e7c2b..736bc8b523781a2df3520352581e083ada00c329 100755 --- a/examples/ssl/vgg16_cifar_ssl_channels_training.yaml +++ b/examples/ssl/vgg16_cifar_ssl_channels_training.yaml @@ -5,7 +5,7 @@ # time python3 compress_classifier.py --arch vgg16_cifar ../../../data.cifar10 -p=50 --lr=0.05 --epochs=180 --compress=../ssl/vgg16_cifar_ssl_channels_training.yaml -j=1 --deterministic # # The results below are from the SSL training session, and you can follow-up with some fine-tuning: -# time python3 compress_classifier.py --arch vgg16_cifar ../../../data.cifar10 --resume=checkpoint.vgg16_cifar.pth.tar --lr=0.01 --epochs=20 +# time python3 compress_classifier.py --arch vgg16_cifar ../../../data.cifar10 --resume-from=checkpoint.vgg16_cifar.pth.tar --reset-optimizer --lr=0.01 --epochs=20 # ==> Top1: 91.010 Top5: 99.480 Loss: 0.513 # # Parameters: @@ -74,14 +74,13 @@ extensions: policies: - lr_scheduler: instance_name: training_lr - starting_epoch: 45 + starting_epoch: 0 ending_epoch: 300 frequency: 1 -# After completeing the regularization, we perform network thinning and exit. 
- extension: instance_name: net_thinner - epochs: [179] + epochs: [0] - regularizer: instance_name: Channels_groups_regularizer diff --git a/tests/checkpoints/resnet20_cifar10_checkpoint.pth.tar b/tests/checkpoints/resnet20_cifar10_checkpoint.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..d15a84f847090d934bf91d19736ff577a2752476 Binary files /dev/null and b/tests/checkpoints/resnet20_cifar10_checkpoint.pth.tar differ diff --git a/tests/test_infra.py b/tests/test_infra.py index 2d0524f5bb77cb9d50195e304b1a1d4a1c9e4e1e..478ad3e1b9397d7f3daa1d1d6f22e26130159b24 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -28,7 +28,7 @@ except ImportError: sys.path.append(module_path) import distiller import distiller -from distiller.apputils import save_checkpoint, load_checkpoint +from distiller.apputils import save_checkpoint, load_checkpoint, load_lean_checkpoint from distiller.models import create_model import pretrainedmodels @@ -66,30 +66,65 @@ def test_create_model_pretrainedmodels(): model = create_model(False, 'imagenet', 'no_such_model!') +def _is_similar_param_groups(opt_a, opt_b): + for k in opt_a['param_groups'][0]: + val_a = opt_a['param_groups'][0][k] + val_b = opt_b['param_groups'][0][k] + if (val_a != val_b) and (k != 'params'): + return False + return True + + def test_load(): logger = logging.getLogger('simple_example') logger.setLevel(logging.INFO) - model = create_model(False, 'cifar10', 'resnet20_cifar') - model, compression_scheduler, start_epoch = load_checkpoint(model, '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar') + checkpoint_filename = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar' + src_optimizer_state_dict = torch.load(checkpoint_filename)['optimizer_state_dict'] + + model = create_model(False, 'cifar10', 'resnet20_cifar', 0) + model, compression_scheduler, optimizer, start_epoch = load_checkpoint( + model, checkpoint_filename) assert compression_scheduler is not None - assert start_epoch == 180 + assert optimizer is not None, 'Failed to load the optimizer' + if not _is_similar_param_groups(src_optimizer_state_dict, optimizer.state_dict()): + assert src_optimizer_state_dict == optimizer.state_dict() # this will always fail + assert start_epoch == 1 + + +def test_load_state_dict_implicit(): + # prepare lean checkpoint + state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict') + + with tempfile.NamedTemporaryFile() as tmpfile: + torch.save({'state_dict': state_dict_arrays}, tmpfile.name) + model = create_model(False, 'cifar10', 'resnet20_cifar') + with pytest.raises(KeyError): + model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name) -def test_load_state_dict(): +def test_load_lean_checkpoint_1(): # prepare lean checkpoint state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict') with tempfile.NamedTemporaryFile() as tmpfile: torch.save({'state_dict': state_dict_arrays}, tmpfile.name) model = create_model(False, 'cifar10', 'resnet20_cifar') - model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name) + model, compression_scheduler, optimizer, start_epoch = load_checkpoint( + model, tmpfile.name, lean_checkpoint=True) - assert len(list(model.named_modules())) >= len([x for x in state_dict_arrays if x.endswith('weight')]) > 0 assert compression_scheduler is None + assert optimizer is None assert start_epoch == 0 +def test_load_lean_checkpoint_2(): + 
checkpoint_filename = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar' + + model = create_model(False, 'cifar10', 'resnet20_cifar', 0) + model = load_lean_checkpoint(model, checkpoint_filename) + + def test_load_dumb_checkpoint(): # prepare lean checkpoint state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict') @@ -98,24 +133,42 @@ def test_load_dumb_checkpoint(): torch.save(state_dict_arrays, tmpfile.name) model = create_model(False, 'cifar10', 'resnet20_cifar') with pytest.raises(ValueError): - model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name) + model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name) def test_load_negative(): with pytest.raises(FileNotFoundError): model = create_model(False, 'cifar10', 'resnet20_cifar') - model, compression_scheduler, start_epoch = load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar') + model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, + 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar') def test_load_gpu_model_on_cpu(): # Issue #148 CPU_DEVICE_ID = -1 + CPU_DEVICE_NAME = 'cpu' + checkpoint_filename = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar' + model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID) - model, compression_scheduler, start_epoch = load_checkpoint(model, - '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar') + model, compression_scheduler, optimizer, start_epoch = load_checkpoint( + model, checkpoint_filename) + assert compression_scheduler is not None - assert start_epoch == 180 - assert distiller.model_device(model) == 'cpu' + assert optimizer is not None + assert distiller.utils.optimizer_device_name(optimizer) == CPU_DEVICE_NAME + assert start_epoch == 1 + assert distiller.model_device(model) == CPU_DEVICE_NAME + + +def test_load_gpu_model_on_cpu_lean_checkpoint(): + CPU_DEVICE_ID = -1 + CPU_DEVICE_NAME = 'cpu' + checkpoint_filename = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar' + + model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID) + model = load_lean_checkpoint(model, checkpoint_filename, + model_device=CPU_DEVICE_NAME) + assert distiller.model_device(model) == CPU_DEVICE_NAME def test_load_gpu_model_on_cpu_with_thinning(): @@ -137,13 +190,10 @@ def test_load_gpu_model_on_cpu_with_thinning(): distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None) assert hasattr(gpu_model, 'thinning_recipes') scheduler = distiller.CompressionScheduler(gpu_model) - save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None) + save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None, + dir='checkpoints') CPU_DEVICE_ID = -1 cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID) - load_checkpoint(cpu_model, "checkpoint.pth.tar") + load_lean_checkpoint(cpu_model, "checkpoints/checkpoint.pth.tar") assert distiller.model_device(cpu_model) == 'cpu' - - -if __name__ == '__main__': - test_load_gpu_model_on_cpu() diff --git a/tests/test_pruning.py b/tests/test_pruning.py index 8b9c4ddf45e088a87d5b3cfa0cb0e033347eb811..443e45240083ad88108ebd4940764244c6a9c6e9 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -23,7 +23,7 @@ import distiller import common import pytest from distiller.models import create_model 
-from distiller.apputils import save_checkpoint, load_checkpoint +from distiller.apputils import save_checkpoint, load_checkpoint, load_lean_checkpoint # Logging configuration logging.basicConfig(level=logging.INFO) @@ -296,7 +296,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel): conv2 = common.find_module_by_name(model_2, pair[1]) assert conv2 is not None with pytest.raises(KeyError): - model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') + model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar') compression_scheduler = distiller.CompressionScheduler(model) hasattr(model, 'thinning_recipes') @@ -304,13 +304,13 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel): # (2) save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler) - model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') + model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar') assert hasattr(model_2, 'thinning_recipes') logger.info("test_arbitrary_channel_pruning - Done") # (3) save_checkpoint(epoch=0, arch=config.arch, model=model_2, optimizer=None, scheduler=compression_scheduler) - model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') + model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar') assert hasattr(model_2, 'thinning_recipes') logger.info("test_arbitrary_channel_pruning - Done 2")
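As a rough illustration of the checkpointing API after this patch: save_checkpoint() now records the optimizer class alongside its state dict, load_checkpoint() returns a 4-tuple that includes the rebuilt optimizer (older checkpoints that only carry the legacy 'optimizer' field come back with optimizer=None plus a warning), and load_lean_checkpoint() restores weights only. The following is a minimal sketch, not part of the patch, assuming a CIFAR-10 ResNet-20 pinned to the CPU and a writable local 'checkpoints/' directory (both chosen here purely for illustration):

import os
import torch
from distiller.apputils import (save_checkpoint, load_checkpoint,
                                load_lean_checkpoint)
from distiller.models import create_model

os.makedirs('checkpoints', exist_ok=True)

# device_ids=-1 keeps the example on the CPU, as in the updated tests.
model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                            weight_decay=1e-4)

# The checkpoint now stores 'optimizer_state_dict' and 'optimizer_type'.
save_checkpoint(epoch=1, arch='resnet20_cifar', model=model,
                optimizer=optimizer, dir='checkpoints')

# load_checkpoint() rebuilds the optimizer from the recorded class and state
# dict and returns (model, compression_scheduler, optimizer, start_epoch).
model, compression_scheduler, optimizer, start_epoch = load_checkpoint(
    model, 'checkpoints/checkpoint.pth.tar', model_device='cpu')

# Weights-only restore, e.g. for evaluation or a knowledge-distillation teacher.
model = load_lean_checkpoint(model, 'checkpoints/checkpoint.pth.tar',
                             model_device='cpu')

When compress_classifier.py resumes with --reset-optimizer it discards the restored optimizer and resets start_epoch to 0, which is why the example schedules in this patch were rebased to start at epoch 0; without that flag, the resumed epoch is forwarded as distiller.file_config(..., resumed_epoch=start_epoch - 1) so the YAML-defined LR schedulers pick up from the matching last_epoch.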
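The LR-scheduling change is easier to see in isolation: CompressionScheduler.on_epoch_begin() now forwards keyword arguments to its policies, and LRPolicy steps ReduceLROnPlateau (which does not inherit from _LRScheduler) with the supplied metrics value while stepping other schedulers by epoch; compress_classifier.py passes metrics=vloss, seeding the first epoch with a 10**6 sentinel. A minimal sketch, assuming a plateau scheduler and a hand-rolled loop in which the loss values are placeholders:

import torch
import distiller
from distiller.policy import LRPolicy
from distiller.models import create_model

model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

scheduler = distiller.CompressionScheduler(model)
scheduler.add_policy(LRPolicy(plateau), starting_epoch=0, ending_epoch=10,
                     frequency=1)

vloss = 10 ** 6                      # sentinel used on the very first epoch
for epoch in range(3):
    # The metrics kwarg is routed through to ReduceLROnPlateau.step(metrics).
    scheduler.on_epoch_begin(epoch, metrics=vloss)
    # ... train for one epoch, then validate ...
    vloss = 1.0 / (epoch + 1)        # placeholder validation loss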