diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index c35c31b516639f5c439b07bd3ddedbe028de98f8..9b45391d860bef3c85061a9d690691e62205503a 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -145,9 +145,21 @@ def config_logger(experiment_name): msglogger.info('Log file for this run: ' + os.path.realpath(log_filename)) return msglogger +def check_pytorch_version(): + if torch.__version__ < '0.4.0': + print("\nNOTICE:") + print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to " + "PyTorch API changes which are not backward-compatible.\n" + "Please install PyTorch 0.4.0 or its derivative.\n" + "If you are using a virtual environment, do not forget to update it:\n" + " 1. Deactivate the old environment\n" + " 2. Install the new environment\n" + " 3. Activate the new environment") + exit(1) def main(): global msglogger + check_pytorch_version() args = parser.parse_args() msglogger = config_logger(args.name)