diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index 2cfbbba8d2878648cf89d90d909a1b1fa035ada7..fc90e8877deb6f8da54e491aea29f5e342fca328 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -62,10 +62,15 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
     filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar'
     fullpath_best = os.path.join(dir, filename_best)

-    checkpoint = {}
-    checkpoint['epoch'] = epoch
-    checkpoint['arch'] = arch
-    checkpoint['state_dict'] = model.state_dict()
+    checkpoint = {'epoch': epoch, 'state_dict': model.state_dict(), 'arch': arch}
+    try:
+        checkpoint['is_parallel'] = model.is_parallel
+        checkpoint['dataset'] = model.dataset
+        if not arch:
+            checkpoint['arch'] = model.arch
+    except AttributeError:
+        pass
+
     if optimizer is not None:
         checkpoint['optimizer_state_dict'] = optimizer.state_dict()
         checkpoint['optimizer_type'] = type(optimizer)
@@ -105,7 +110,10 @@
     """Load a pytorch training checkpoint.

     Args:
-        model: the pytorch model to which we will load the parameters
+        model: the pytorch model to which we will load the parameters. You can
+          specify model=None if the checkpoint contains enough metadata to infer
+          the model. The order of the arguments is misleading and clunky, and is
+          kept this way for backward compatibility.
         chkpt_file: the checkpoint file
         lean_checkpoint: if set, read into model only 'state_dict' field
         optimizer: [deprecated argument]
@@ -159,8 +167,24 @@
             msglogger.warning('Optimizer could not be loaded from checkpoint.')
             return None

+    def _create_model_from_ckpt():
+        try:
+            return distiller.models.create_model(False, checkpoint['dataset'], checkpoint['arch'],
+                                                 checkpoint['is_parallel'], device_ids=None)
+        except KeyError:
+            return None
+
+    def _sanity_check():
+        try:
+            if model.arch != checkpoint["arch"]:
+                raise ValueError("The model architecture does not match the checkpoint architecture")
+        except (AttributeError, KeyError):
+            # One of the values is missing so we can't perform the comparison
+            pass
+
     if not os.path.isfile(chkpt_file):
         raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+    assert optimizer is None, "argument optimizer is deprecated and must be set to None"

     msglogger.info("=> loading checkpoint %s", chkpt_file)
     checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
@@ -171,9 +195,14 @@
     if 'state_dict' not in checkpoint:
         raise ValueError("Checkpoint must contain the model parameters under the key 'state_dict'")

+    if not model:
+        model = _create_model_from_ckpt()
+    if not model:
+        raise ValueError("You didn't provide a model, and the checkpoint doesn't contain "
+                         "enough information to create one")
+
     checkpoint_epoch = checkpoint.get('epoch', None)
     start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
-
     compression_scheduler = None
     normalize_dataparallel_keys = False
     if 'compression_sched' in checkpoint:
@@ -217,4 +246,5 @@
         optimizer = _load_optimizer()
     msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
                                                                    e=checkpoint_epoch))
+    _sanity_check()
     return model, compression_scheduler, optimizer, start_epoch
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 01f78c3760a10c4ff8e45659b6eb2f016af11a4e..83e653938e833775b17317a834055e0e2b7ac423 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -111,14 +111,19 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
                          arch, dataset))
     if torch.cuda.is_available() and device_ids != -1:
         device = 'cuda'
-        if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel:
-            model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
-        elif parallel:
-            model = torch.nn.DataParallel(model, device_ids=device_ids)
+        if parallel:
+            if arch.startswith('alexnet') or arch.startswith('vgg'):
+                model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
+            else:
+                model = torch.nn.DataParallel(model, device_ids=device_ids)
     else:
         device = 'cpu'

+    # Cache some attributes which describe the model
     _set_model_input_shape_attr(model, arch, dataset, pretrained, cadene)
+    model.arch = arch
+    model.dataset = dataset
+    model.is_parallel = parallel
     return model.to(device)


diff --git a/tests/test_infra.py b/tests/test_infra.py
index c778e514f713abfbd4fd3c1ebb76d60c4d8f4790..2a266de82ed97b7252750ea5f519682d964a2150 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import logging
 import tempfile
-
 import torch
 import pytest
 import distiller
@@ -182,8 +182,8 @@ def test_load_gpu_model_on_cpu_with_thinning():
     distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None)
     assert hasattr(gpu_model, 'thinning_recipes')
     scheduler = distiller.CompressionScheduler(gpu_model)
-    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None,
-                    dir='checkpoints')
+    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model,
+                    scheduler=scheduler, optimizer=None, dir='checkpoints')

     CPU_DEVICE_ID = -1
     cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
@@ -269,3 +269,30 @@ def test_get_dummy_input(device):
     check_shape_device(t[0], shape[0], expected_device)
     check_shape_device(t[1][0], shape[1][0], expected_device)
     check_shape_device(t[1][1], shape[1][1], expected_device)
+
+
+def test_load_checkpoint_without_model():
+    checkpoint_filename = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar'
+    # Load a checkpoint w/o specifying the model: this should fail because the loaded
+    # checkpoint is old and does not have the required metadata to create a model.
+    with pytest.raises(ValueError):
+        load_checkpoint(model=None, chkpt_file=checkpoint_filename)
+
+    for model_device in (None, 'cuda', 'cpu'):
+        # Now we create a new model, save a checkpoint, and load it w/o specifying the model.
+        # This should succeed because the checkpoint has enough metadata to create the model.
+        model = create_model(False, 'cifar10', 'resnet20_cifar', 0)
+        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, checkpoint_filename)
+        save_checkpoint(epoch=0, arch='resnet20_cifar', model=model, name='eraseme',
+                        scheduler=compression_scheduler, optimizer=None, dir='checkpoints')
+        temp_checkpoint = os.path.join("checkpoints", "eraseme_checkpoint.pth.tar")
+        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model=None,
+                                                                               chkpt_file=temp_checkpoint,
+                                                                               model_device=model_device)
+        assert compression_scheduler is not None
+        assert optimizer is None
+        assert start_epoch == 1
+        assert model
+        assert model.arch == "resnet20_cifar"
+        assert model.dataset == "cifar10"
+        os.remove(temp_checkpoint)
\ No newline at end of file
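
For context, a minimal usage sketch (not part of the diff) of what this change enables: create_model() now caches arch, dataset and is_parallel on the model, so a checkpoint saved from such a model carries enough metadata for load_checkpoint() to rebuild it when model=None is passed. The checkpoint name and directory below are illustrative, not taken from the diff.

# Hedged usage sketch -- checkpoint name/path are illustrative assumptions.
from distiller.apputils.checkpoint import save_checkpoint, load_checkpoint
from distiller.models import create_model

# create_model() records arch/dataset/is_parallel on the returned model.
model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
save_checkpoint(epoch=0, arch='resnet20_cifar', model=model, optimizer=None,
                name='example', dir='checkpoints')

# With the cached metadata stored in the checkpoint, the caller no longer has to
# pre-build the model: load_checkpoint() reconstructs it when model=None.
model, scheduler, optimizer, start_epoch = load_checkpoint(
    model=None, chkpt_file='checkpoints/example_checkpoint.pth.tar')
assert model.arch == 'resnet20_cifar' and model.dataset == 'cifar10'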