Skip to content
Snippets Groups Projects
Commit 1e7e9835 authored by Neta Zmora's avatar Neta Zmora
Browse files
parents d8c97cdd 1f48fa64
No related branches found
No related tags found
No related merge requests found
...@@ -89,8 +89,11 @@ def test_load_state_dict_implicit(): ...@@ -89,8 +89,11 @@ def test_load_state_dict_implicit():
with tempfile.NamedTemporaryFile() as tmpfile: with tempfile.NamedTemporaryFile() as tmpfile:
torch.save({'state_dict': state_dict_arrays}, tmpfile.name) torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
model = create_model(False, 'cifar10', 'resnet20_cifar') model = create_model(False, 'cifar10', 'resnet20_cifar')
with pytest.raises(KeyError): model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name)
model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name)
assert compression_scheduler is None
assert optimizer is None
assert start_epoch == 0
def test_load_lean_checkpoint_1(): def test_load_lean_checkpoint_1():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment