diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/tests/test_basic.py b/tests/test_basic.py index 5399d70bfa04964905d940bfc4bcea35bd28cee5..19f742a5d1b759a045c760c584236618e4198846 100755 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -15,6 +15,11 @@ # import torch +import os +import sys +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) import distiller def test_sparsity(): diff --git a/tests/test_infra.py b/tests/test_infra.py index d1888a2ba0d147bad244163b401079d548315c56..b27d201532dfaaaecf308215de32453240a67421 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -14,11 +14,17 @@ # limitations under the License. # -def test_load(): - from models import create_model - from apputils import load_checkpoint - import logging +import logging +import os +import sys +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) + +from models import create_model +from apputils import load_checkpoint +def test_load(): logger = logging.getLogger('simple_example') logger.setLevel(logging.INFO)