diff --git a/tests/test_pruning.py b/tests/test_pruning.py index 277fa959bc28906736b709454a10db6f0f0ce6b8..023932116effb2ced2fdf6bb04ed4f789762464e 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -317,7 +317,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel): # We save 3 times, and load twice, to make sure to cover some corner cases: # - Make sure that after loading, the model still has hold of the thinning recipes # - Make sure that after a 2nd load, there no problem loading (in this case, the - # - tensors are already thin, so this is a new flow) + # tensors are already thin, so this is a new flow) # (1) save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None) model_2 = create_model(False, config.dataset, config.arch, parallel=is_parallel) @@ -325,14 +325,13 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel): model_2(dummy_input) conv2 = common.find_module_by_name(model_2, pair[1]) assert conv2 is not None - with pytest.raises(KeyError): - model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar') - compression_scheduler = distiller.CompressionScheduler(model) - hasattr(model, 'thinning_recipes') + model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar') + assert hasattr(model_2, 'thinning_recipes') run_forward_backward(model, optimizer, dummy_input) # (2) + compression_scheduler = distiller.CompressionScheduler(model) save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler) model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar') assert hasattr(model_2, 'thinning_recipes')