From 961bfc8913c9c7ea43e7011a90493916d21b598b Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 21 Aug 2019 13:47:25 +0300
Subject: [PATCH] test_pruning.py: adjust test after relaxing thinning
 checkpoint loading

---
 tests/test_pruning.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 277fa95..0239321 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -317,7 +317,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
     # We save 3 times, and load twice, to make sure to cover some corner cases:
     #  - Make sure that after loading, the model still has hold of the thinning recipes
     #  - Make sure that after a 2nd load, there no problem loading (in this case, the
-    #  - tensors are already thin, so this is a new flow)
+    #    tensors are already thin, so this is a new flow)
     # (1)
     save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None)
     model_2 = create_model(False, config.dataset, config.arch, parallel=is_parallel)
@@ -325,14 +325,13 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
     model_2(dummy_input)
     conv2 = common.find_module_by_name(model_2, pair[1])
     assert conv2 is not None
-    with pytest.raises(KeyError):
-        model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
-    compression_scheduler = distiller.CompressionScheduler(model)
-    hasattr(model, 'thinning_recipes')
+    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
     run_forward_backward(model, optimizer, dummy_input)
 
     # (2)
+    compression_scheduler = distiller.CompressionScheduler(model)
     save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None,
                     scheduler=compression_scheduler)
     model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
     assert hasattr(model_2, 'thinning_recipes')
-- 
GitLab
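
For readers skimming the patch, a minimal sketch of the flow the adjusted test now expects: loading a lean checkpoint of an already-thinned model should succeed (previously the test expected a KeyError) and the loaded model should still carry its thinning recipes. The identifiers (create_model, save_checkpoint, load_lean_checkpoint, thinning_recipes) come from the test above; the import paths and the cifar10/resnet20_cifar arguments are assumptions about Distiller's layout, not part of the patch.

    from distiller.models import create_model
    from distiller.apputils import save_checkpoint, load_lean_checkpoint

    # Build a model and physically thin it (pruning/thinning steps elided).
    model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
    # ... prune channels and thin the model here ...

    # Save a checkpoint of the thinned model.
    save_checkpoint(epoch=0, arch='resnet20_cifar', model=model, optimizer=None)

    # After relaxing the thinning-checkpoint loading, a lean load succeeds
    # and re-attaches the thinning recipes to the loaded model.
    model_2 = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')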