From 1f48fa64131596b181ebd26a59d2679f7f877dee Mon Sep 17 00:00:00 2001 From: Bar <elhararb@gmail.com> Date: Mon, 6 May 2019 17:07:44 +0300 Subject: [PATCH] Fix broken load test (#245) In a former commit, distiller accepts checkpoints that do not contain 'optimizer' argument. However, this change was not reflected in the relevant test. --- tests/test_infra.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_infra.py b/tests/test_infra.py index 720707c..c90949e 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -89,8 +89,11 @@ def test_load_state_dict_implicit(): with tempfile.NamedTemporaryFile() as tmpfile: torch.save({'state_dict': state_dict_arrays}, tmpfile.name) model = create_model(False, 'cifar10', 'resnet20_cifar') - with pytest.raises(KeyError): - model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name) + model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name) + + assert compression_scheduler is None + assert optimizer is None + assert start_epoch == 0 def test_load_lean_checkpoint_1(): -- GitLab