Fix issue #148 + refactor load_checkpoint.py (#153)
    Neta Zmora authored
The root cause of issue #148 is that DataParallel modules cannot execute on the CPU,
on machines that have both CPUs and GPUs.
Therefore, we don't wrap models with DataParallel when they are loaded for the CPU, but we do
wrap them with DataParallel when they are loaded onto the GPUs (to make them run faster).
The names of the module keys saved in a checkpoint file depend on whether the modules are
wrapped by a DataParallel module or not.  So loading a checkpoint that was created on the GPU
into a CPU model (and vice versa) fails on the key names.
This is all PyTorch behavior, and despite the community asking for a fix -
e.g. https://github.com/pytorch/pytorch/issues/7457 - it is still pending.
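
For concreteness, a minimal sketch of the mismatch (an illustrative module; this runs on the
CPU because nn.DataParallel only dispatches to GPUs at forward time):

```python
import torch.nn as nn

plain = nn.Linear(4, 2)
wrapped = nn.DataParallel(nn.Linear(4, 2))

print(list(plain.state_dict().keys()))    # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

# A strict load_state_dict() across the two layouts fails, complaining
# about missing/unexpected keys because of the 'module.' prefix:
# plain.load_state_dict(wrapped.state_dict())
```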
    
This commit contains code to catch key errors when loading a GPU-generated checkpoint
(i.e. one saved from a DataParallel-wrapped model) into a CPU model, and to convert the
names of the keys.
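
A minimal sketch of that approach (not the commit's exact code; convert_keys is an
illustrative helper name):

```python
from collections import OrderedDict

import torch.nn as nn


def convert_keys(state_dict):
    # Strip the 'module.' prefix that nn.DataParallel prepends to every
    # parameter name, so a GPU-saved checkpoint fits a plain CPU model.
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items()
    )


cpu_model = nn.Linear(4, 2)
gpu_saved = nn.DataParallel(nn.Linear(4, 2)).state_dict()  # 'module.*' keys

try:
    cpu_model.load_state_dict(gpu_saved)        # raises on the 'module.' keys
except (KeyError, RuntimeError):
    cpu_model.load_state_dict(convert_keys(gpu_saved))
```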
    
This PR also merges refactoring to load_checkpoint.py done by @barrh, who also added
a test to further exercise checkpoint loading.