Skip to content
Snippets Groups Projects
Commit ba05f6cf authored by Neta Zmora's avatar Neta Zmora
Browse files

CPU support: fix the case of loading a thinned GPU-model on the CPU

This commit fixes (and adds a test for) the case where we wish to load
a thinned GPU checkpoint onto the CPU.
parent 0ae07549
No related branches found
No related tags found
No related merge requests found
...@@ -85,7 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): ...@@ -85,7 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file) raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
msglogger.info("=> loading checkpoint %s", chkpt_file) msglogger.info("=> loading checkpoint %s", chkpt_file)
checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage) checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint))) msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))
if 'state_dict' not in checkpoint: if 'state_dict' not in checkpoint:
...@@ -121,7 +121,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): ...@@ -121,7 +121,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
# Cache the recipes in case we need them later # Cache the recipes in case we need them later
model.thinning_recipes = checkpoint['thinning_recipes'] model.thinning_recipes = checkpoint['thinning_recipes']
if normalize_dataparallel_keys: if normalize_dataparallel_keys:
model.thinning_recipes = {normalize_module_name(k): v for k, v in model.thinning_recipes.items()} model.thinning_recipes = [distiller.get_normalized_recipe(recipe) for recipe in model.thinning_recipes]
distiller.execute_thinning_recipes_list(model, distiller.execute_thinning_recipes_list(model,
compression_scheduler.zeros_mask_dict, compression_scheduler.zeros_mask_dict,
model.thinning_recipes) model.thinning_recipes)
......
...@@ -61,7 +61,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers', ...@@ -61,7 +61,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
'ChannelRemover', 'remove_channels', 'ChannelRemover', 'remove_channels',
'FilterRemover', 'remove_filters', 'FilterRemover', 'remove_filters',
'find_nonzero_channels', 'find_nonzero_channels_list', 'find_nonzero_channels', 'find_nonzero_channels_list',
'execute_thinning_recipes_list'] 'execute_thinning_recipes_list', 'get_normalized_recipe']
def create_graph(dataset, arch): def create_graph(dataset, arch):
...@@ -77,6 +77,12 @@ def create_graph(dataset, arch): ...@@ -77,6 +77,12 @@ def create_graph(dataset, arch):
return SummaryGraph(model, dummy_input) return SummaryGraph(model, dummy_input)
def get_normalized_recipe(recipe):
    """Return a copy of *recipe* with normalized module/parameter key names.

    Each key in the recipe's ``modules`` and ``parameters`` dicts is passed
    through ``normalize_module_name`` (presumably stripping DataParallel's
    'module.' prefix — confirm against normalize_module_name's definition);
    the values are carried over unchanged.
    """
    def _normalize_keys(mapping):
        # Rebuild the dict, renaming keys; values are shared, not copied.
        return {normalize_module_name(name): value for name, value in mapping.items()}

    return ThinningRecipe(modules=_normalize_keys(recipe.modules),
                          parameters=_normalize_keys(recipe.parameters))
def param_name_2_layer_name(param_name):
    """Convert a parameter name to the name of its owning layer.

    NOTE: ``len('weights')`` is 7, so this removes the last seven characters,
    which for names like ``'layer1.0.conv1.weight'`` drops the trailing
    ``'.weight'`` (dot included), yielding ``'layer1.0.conv1'``.
    """
    suffix_length = len('weights')  # 7 characters: '.weight' plus its dot
    return param_name[:-suffix_length]
......
...@@ -26,7 +26,7 @@ if module_path not in sys.path: ...@@ -26,7 +26,7 @@ if module_path not in sys.path:
sys.path.append(module_path) sys.path.append(module_path)
import distiller import distiller
from apputils import load_checkpoint from apputils import save_checkpoint, load_checkpoint
from models import create_model from models import create_model
...@@ -39,6 +39,7 @@ def test_load(): ...@@ -39,6 +39,7 @@ def test_load():
assert compression_scheduler is not None assert compression_scheduler is not None
assert start_epoch == 180 assert start_epoch == 180
def test_load_state_dict(): def test_load_state_dict():
# prepare lean checkpoint # prepare lean checkpoint
state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict') state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
...@@ -52,6 +53,7 @@ def test_load_state_dict(): ...@@ -52,6 +53,7 @@ def test_load_state_dict():
assert compression_scheduler is None assert compression_scheduler is None
assert start_epoch == 0 assert start_epoch == 0
def test_load_dumb_checkpoint(): def test_load_dumb_checkpoint():
# prepare lean checkpoint # prepare lean checkpoint
state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict') state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
...@@ -62,6 +64,7 @@ def test_load_dumb_checkpoint(): ...@@ -62,6 +64,7 @@ def test_load_dumb_checkpoint():
with pytest.raises(ValueError): with pytest.raises(ValueError):
model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name) model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
def test_load_negative(): def test_load_negative():
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
model = create_model(False, 'cifar10', 'resnet20_cifar') model = create_model(False, 'cifar10', 'resnet20_cifar')
...@@ -69,12 +72,42 @@ def test_load_negative(): ...@@ -69,12 +72,42 @@ def test_load_negative():
def test_load_gpu_model_on_cpu(): def test_load_gpu_model_on_cpu():
model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1) # Issue #148
CPU_DEVICE_ID = -1
model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
model, compression_scheduler, start_epoch = load_checkpoint(model, model, compression_scheduler, start_epoch = load_checkpoint(model,
'../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar') '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
assert compression_scheduler is not None assert compression_scheduler is not None
assert start_epoch == 180 assert start_epoch == 180
assert distiller.model_device(model) == 'cpu' assert distiller.model_device(model) == 'cpu'
def test_load_gpu_model_on_cpu_with_thinning():
    """Regression test for issue #148: load a thinned GPU checkpoint onto the CPU.

    Steps:
      1. Create a GPU model and remove 50% of the filters in one of the
         layers (thinning).
      2. Save the thinned model in a checkpoint file.
      3. Load the checkpoint and place the model on the CPU.
    """
    CPU_DEVICE_ID = -1  # create_model treats device_ids=-1 as "use the CPU"
    gpu_model = create_model(False, 'cifar10', 'resnet20_cifar')
    conv_pname = "module.layer1.0.conv1.weight"
    conv_p = distiller.model_find_param(gpu_model, conv_pname)
    pruner = distiller.pruning.L1RankedStructureParameterPruner("test_pruner", group_type="Filters",
                                                                desired_sparsity=0.5, weights=conv_pname)
    zeros_mask_dict = distiller.create_model_masks_dict(gpu_model)
    pruner.set_param_mask(conv_p, conv_pname, zeros_mask_dict, meta=None)

    # Use the mask to prune, then physically remove the zeroed filters (thinning)
    zeros_mask_dict[conv_pname].apply_mask(conv_p)
    distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None)
    assert hasattr(gpu_model, 'thinning_recipes')

    # Save the thinned GPU model.
    # NOTE(review): save_checkpoint appears to write "checkpoint.pth.tar" to the
    # current working directory — confirm, and consider a tmpdir for isolation.
    scheduler = distiller.CompressionScheduler(gpu_model)
    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None)

    # Load the thinned checkpoint onto a freshly-created CPU model
    # (fix: CPU_DEVICE_ID was previously assigned twice with the same value).
    cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
    load_checkpoint(cpu_model, "checkpoint.pth.tar")
    assert distiller.model_device(cpu_model) == 'cpu'
if __name__ == '__main__': if __name__ == '__main__':
test_load_gpu_model_on_cpu() test_load_gpu_model_on_cpu()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment