From 5e198621b131bbc6c324ccdd3833f8509c0233e0 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 25 Apr 2018 12:00:19 +0300 Subject: [PATCH] Fixed test execution in new environment --- tests/__init__.py | 0 tests/test_basic.py | 5 +++++ tests/test_infra.py | 14 ++++++++++---- 3 files changed, 15 insertions(+), 4 deletions(-) delete mode 100644 tests/__init__.py diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_basic.py b/tests/test_basic.py index 5399d70..19f742a 100755 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -15,6 +15,11 @@ # import torch +import os +import sys +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) import distiller def test_sparsity(): diff --git a/tests/test_infra.py b/tests/test_infra.py index d1888a2..b27d201 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -14,11 +14,17 @@ # limitations under the License. # -def test_load(): - from models import create_model - from apputils import load_checkpoint - import logging +import logging +import os +import sys +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) + +from models import create_model +from apputils import load_checkpoint +def test_load(): logger = logging.getLogger('simple_example') logger.setLevel(logging.INFO) -- GitLab