From 5e198621b131bbc6c324ccdd3833f8509c0233e0 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 25 Apr 2018 12:00:19 +0300
Subject: [PATCH] Fixed test execution in new environment

---
 tests/__init__.py   |  0
 tests/test_basic.py |  5 +++++
 tests/test_infra.py | 14 ++++++++++----
 3 files changed, 15 insertions(+), 4 deletions(-)
 delete mode 100644 tests/__init__.py

diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_basic.py b/tests/test_basic.py
index 5399d70..19f742a 100755
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -15,6 +15,11 @@
 #
 
 import torch
+import os
+import sys
+module_path = os.path.abspath(os.path.join('..'))
+if module_path not in sys.path:
+    sys.path.append(module_path)
 import distiller
 
 def test_sparsity():
diff --git a/tests/test_infra.py b/tests/test_infra.py
index d1888a2..b27d201 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -14,11 +14,17 @@
 # limitations under the License.
 #
 
-def test_load():
-    from models import create_model
-    from apputils import load_checkpoint
-    import logging
+import logging 
+import os
+import sys
+module_path = os.path.abspath(os.path.join('..'))
+if module_path not in sys.path:
+    sys.path.append(module_path)
+
+from models import create_model
+from apputils import load_checkpoint
 
+def test_load():
     logger = logging.getLogger('simple_example')
     logger.setLevel(logging.INFO)
 
-- 
GitLab