Skip to content
Snippets Groups Projects
Unverified Commit b64be690 authored by Neta Zmora's avatar Neta Zmora Committed by GitHub
Browse files

More robust handling of data-parallel/serial graphs (#27)

Remove the complicated logic trying to handle data-parallel models as
serially-processed models, and vice versa.

*Function distiller.utils.make_non_parallel_copy() does the heavy lifting of
replacing  all instances of nn.DataParallel in a model with instances of
DoNothingModuleWrapper.
The DoNothingModuleWrapper wrapper does nothing but forward to the
wrapped module.  This is a trick we use to transform a data-parallel model
to a serial-processed model.

*SummaryGraph uses a copy of the model after the model is processed by
distiller.make_non_parallel_copy() which renders the model non-data-parallel.

*The same goes for model_performance_summary()

*Model inputs are explicitly placed on the Cuda device, since now all models are
Executed on the CPU.  Previously, if a model was not created using
nn.DataParallel, then the model was not explicitly placed on the Cuda device.

*The logic in distiller.CompressionScheduler that attempted to load a
model parallel model and process it serially, or load a serial model and
process it data-parallel, was removed.  This removes a lot of fuzziness and makes
the code more robust: we do not needlessly try to be heroes.

* model summaries - remove pytorch 0.4 warning

* create_model: remove redundant .cuda() call

* Tests: support both parallel and serial tests
parent 51a7df35
No related branches found
No related tags found
No related merge requests found
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment