From ba653d9ac4875d66f930ff26024414f05fa6920f Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 16 May 2018 15:38:35 +0300
Subject: [PATCH] Check if correct version of PyTorch is installed.

The 'master' branch now uses PyTorch 0.4, which has API changes that
are not backward compatible with PyTorch 0.3.

After we've upgraded Distiller's internal implementation to be
compatible with PyTorch 0.4, we've added a check that you are using
the correct PyTorch version.

Note that we only perform this check in the sample image classifier
compression application.
---
 .../classifier_compression/compress_classifier.py    | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index c35c31b..9b45391 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -145,9 +145,21 @@ def config_logger(experiment_name):
     msglogger.info('Log file for this run: ' + os.path.realpath(log_filename))
     return msglogger
 
+def check_pytorch_version():
+    if torch.__version__ < '0.4.0':
+        print("\nNOTICE:")
+        print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to "
+              "PyTorch API changes which are not backward-compatible.\n"
+              "Please install PyTorch 0.4.0 or its derivative.\n"
+              "If you are using a virtual environment, do not forget to update it:\n"
+              "  1. Deactivate the old environment\n"
+              "  2. Install the new environment\n"
+              "  3. Activate the new environment")
+        exit(1)
 
 def main():
     global msglogger
+    check_pytorch_version()
     args = parser.parse_args()
     msglogger = config_logger(args.name)
 
-- 
GitLab