diff --git a/tests/test_infra.py b/tests/test_infra.py
index 720707c4cb2928815e991095a2dc511272b54b13..c90949e681a9eba7b577d864fb05989d9a084f78 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -89,8 +89,11 @@ def test_load_state_dict_implicit():
     with tempfile.NamedTemporaryFile() as tmpfile:
         torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
         model = create_model(False, 'cifar10', 'resnet20_cifar')
-        with pytest.raises(KeyError):
-            model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name)
+        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name)
+
+    assert compression_scheduler is None
+    assert optimizer is None
+    assert start_epoch == 0
 
 
 def test_load_lean_checkpoint_1():