diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py index 69e2a7cc4fafdb049f2d3e5fb144614fb42a99e5..10c1ed97b2b98752bce500cbb8e99a9cb441f122 100755 --- a/apputils/checkpoint.py +++ b/apputils/checkpoint.py @@ -22,6 +22,7 @@ a pruning session, or for querying the pruning schedule of a sparse model. import os import shutil +from errno import ENOENT import logging import torch import distiller @@ -44,8 +45,8 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None, dir: directory in which to save the checkpoint """ if not os.path.isdir(dir): - msglogger.info("Error: Directory to save checkpoint doesn't exist - {0}".format(os.path.abspath(dir))) - exit(1) + raise IOError(ENOENT, 'Checkpoint directory does not exist at', os.path.abspath(dir)) + filename = 'checkpoint.pth.tar' if name is None else name + '_checkpoint.pth.tar' fullpath = os.path.join(dir, filename) msglogger.info("Saving checkpoint to: %s" % fullpath) @@ -120,5 +121,4 @@ def load_checkpoint(model, chkpt_file, optimizer=None): model.load_state_dict(checkpoint['state_dict']) return model, compression_scheduler, start_epoch else: - msglogger.info("Error: no checkpoint found at %s", chkpt_file) - exit(1) + raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file) diff --git a/jupyter/alexnet_insights.ipynb b/jupyter/alexnet_insights.ipynb index 1d4e87a4e64aad12b884f58eb98e70fe2aec1faa..bedaaf0b50fd8ea17984afed6b5c9ad26fd2d4d0 100644 --- a/jupyter/alexnet_insights.ipynb +++ b/jupyter/alexnet_insights.ipynb @@ -89,7 +89,7 @@ "checkpoint_file = \"../examples/classifier_compression/alexnet.checkpoint.89.pth.tar\"\n", "try:\n", " load_checkpoint(epoch89_model, checkpoint_file);\n", - "except NameError as e:\n", + "except Exception as e:\n", " print(\"Did you forget to download the checkpoint file?\")\n", " raise e \n", " \n", @@ -131,7 +131,7 @@ "checkpoint_file = \"../examples/classifier_compression/checkpoint.alexnet.schedule_sensitivity_2D-reg.pth.tar\"\n", "try:\n", " load_checkpoint(reg2D_model, checkpoint_file);\n", - "except NameError as e:\n", + "except Exception as e:\n", " print(\"Did you forget to download the checkpoint file?\")\n", " raise e " ] diff --git a/tests/test_infra.py b/tests/test_infra.py index b27d201532dfaaaecf308215de32453240a67421..099598016c5368b19096379ab8ff592473acd2a1 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -14,9 +14,10 @@ # limitations under the License. # -import logging +import logging import os import sys +import pytest module_path = os.path.abspath(os.path.join('..')) if module_path not in sys.path: sys.path.append(module_path) @@ -32,4 +33,8 @@ def test_load(): model, compression_scheduler, start_epoch = load_checkpoint(model, '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar') assert compression_scheduler is not None assert start_epoch == 180 - return True + +def test_load_negative(): + with pytest.raises(FileNotFoundError): + model = create_model(False, 'cifar10', 'resnet20_cifar') + model, compression_scheduler, start_epoch = load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')