Skip to content
Snippets Groups Projects
Commit 1f48fa64 authored by Bar's avatar Bar Committed by Neta Zmora
Browse files

Fix broken load test (#245)

In a former commit, distiller was changed to accept checkpoints that do not
contain an 'optimizer' key. However, that change was not reflected in the
relevant test.
parent 09d2eea3
No related branches found
No related tags found
No related merge requests found
...@@ -89,8 +89,11 @@ def test_load_state_dict_implicit(): ...@@ -89,8 +89,11 @@ def test_load_state_dict_implicit():
with tempfile.NamedTemporaryFile() as tmpfile: with tempfile.NamedTemporaryFile() as tmpfile:
torch.save({'state_dict': state_dict_arrays}, tmpfile.name) torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
model = create_model(False, 'cifar10', 'resnet20_cifar') model = create_model(False, 'cifar10', 'resnet20_cifar')
with pytest.raises(KeyError): model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name)
model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name)
assert compression_scheduler is None
assert optimizer is None
assert start_epoch == 0
def test_load_lean_checkpoint_1(): def test_load_lean_checkpoint_1():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment