diff --git a/distiller/__init__.py b/distiller/__init__.py index 458b824af965d332fc98c59ebeee45af631424a8..e68701064123cc83fe1c6c59e2d18ad5af92b7fb 100755 --- a/distiller/__init__.py +++ b/distiller/__init__.py @@ -107,13 +107,13 @@ def model_find_module(model, module_to_find): def check_pytorch_version(): from pkg_resources import parse_version - if parse_version(torch.__version__) < parse_version('1.1.0'): + required = '1.3.1' + actual = torch.__version__ + if parse_version(actual) < parse_version(required): msg = "\n\nWRONG PYTORCH VERSION\n"\ - "The Distiller \'master\' branch now requires at least PyTorch version 1.1.0 due to "\ - "PyTorch API changes which are not backward-compatible. Version detected is {}.\n"\ - "To make sure PyTorch and all other dependencies are installed with their correct versions, " \ - "go to the Distiller repo root directory and run:\n\n"\ - "pip install -e .\n".format(torch.__version__) + "Required: {}\n" \ + "Installed: {}\n"\ + "Please run 'pip install -e .' from the Distiller repo root dir\n".format(required, actual) raise RuntimeError(msg)