Skip to content
Snippets Groups Projects
Commit ba653d9a authored by Neta Zmora's avatar Neta Zmora
Browse files

Check if correct version of PyTorch is installed.

The 'master' branch now uses PyTorch 0.4, which has API changes that
are not backward compatible with PyTorch 0.3.

After we've upgraded Distiller's internal implementation to be
compatible with PyTorch 0.4, we've added a check that you are using
the correct PyTorch version.

Note that we only perform this check in the sample image classifier
compression application.
parent bd946e68
No related branches found
No related tags found
No related merge requests found
......@@ -145,9 +145,21 @@ def config_logger(experiment_name):
msglogger.info('Log file for this run: ' + os.path.realpath(log_filename))
return msglogger
def check_pytorch_version():
if torch.__version__ < '0.4.0':
print("\nNOTICE:")
print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to "
"PyTorch API changes which are not backward-compatible.\n"
"Please install PyTorch 0.4.0 or its derivative.\n"
"If you are using a virtual environment, do not forget to update it:\n"
" 1. Deactivate the old environment\n"
" 2. Install the new environment\n"
" 3. Activate the new environment")
exit(1)
def main():
global msglogger
check_pytorch_version()
args = parser.parse_args()
msglogger = config_logger(args.name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment