diff --git a/tests/test_infra.py b/tests/test_infra.py index 720707c4cb2928815e991095a2dc511272b54b13..c90949e681a9eba7b577d864fb05989d9a084f78 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -89,8 +89,11 @@ def test_load_state_dict_implicit(): with tempfile.NamedTemporaryFile() as tmpfile: torch.save({'state_dict': state_dict_arrays}, tmpfile.name) model = create_model(False, 'cifar10', 'resnet20_cifar') - with pytest.raises(KeyError): - model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name) + model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name) + + assert compression_scheduler is None + assert optimizer is None + assert start_epoch == 0 def test_load_lean_checkpoint_1():