From 72ef9160b184f4d31d17e5d2fd59f312e1687740 Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Mon, 8 Apr 2019 14:27:21 +0300
Subject: [PATCH] Removed sys.path modifications when importing distiller.
 (#224)

---
 examples/word_language_model/generate.py |  6 ------
 tests/common.py                          |  6 ------
 tests/test_basic.py                      |  5 -----
 tests/test_learning_rate.py              |  6 ------
 tests/test_loss.py                       |  5 -----
 tests/test_model_summary.py              |  5 -----
 tests/test_post_train_quant.py           |  3 ---
 tests/test_pruning.py                    |  2 --
 tests/test_quant_utils.py                |  5 -----
 tests/test_quantizer.py                  |  2 --
 tests/test_ranking.py                    | 10 +---------
 tests/test_thresholding.py               | 20 +++++++++++++++-----
 12 files changed, 16 insertions(+), 59 deletions(-)

diff --git a/examples/word_language_model/generate.py b/examples/word_language_model/generate.py
index 9beb06c..3849b69 100755
--- a/examples/word_language_model/generate.py
+++ b/examples/word_language_model/generate.py
@@ -18,12 +18,6 @@ import data
 # dependency on distiller code.
 # It's a bit ironic that PyTorch's docs advise against this kind of serialization,
 # while PyTorch's samples use it: https://pytorch.org/docs/master/notes/serialization.html
-import os
-import sys
-script_dir = os.path.dirname(__file__)
-module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
-if module_path not in sys.path:
-    sys.path.append(module_path)
 import distiller
 
 parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Language Model')
diff --git a/tests/common.py b/tests/common.py
index a2283f1..792e033 100755
--- a/tests/common.py
+++ b/tests/common.py
@@ -13,13 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-import os
-import sys
 import torch
-module_path = os.path.abspath(os.path.join('..'))
-if module_path not in sys.path:
-    sys.path.append(module_path)
 import distiller
 from distiller.models import create_model
 
diff --git a/tests/test_basic.py b/tests/test_basic.py
index 02091d0..6c06572 100755
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -15,12 +15,7 @@
 #
 
 import torch
-import os
-import sys
 import common
-module_path = os.path.abspath(os.path.join('..'))
-if module_path not in sys.path:
-    sys.path.append(module_path)
 import distiller
 import distiller.models as models
 
diff --git a/tests/test_learning_rate.py b/tests/test_learning_rate.py
index 42767bb..dd23914 100644
--- a/tests/test_learning_rate.py
+++ b/tests/test_learning_rate.py
@@ -14,13 +14,7 @@
 # limitations under the License.
 #
 
-import os
-import sys
 import pytest
-module_path = os.path.abspath(os.path.join('..'))
-if module_path not in sys.path:
-    sys.path.append(module_path)
-
 import torch
 from torch.optim import Optimizer
 from distiller.learning_rate import MultiStepMultiGammaLR
diff --git a/tests/test_loss.py b/tests/test_loss.py
index 6c197d0..27452f8 100644
--- a/tests/test_loss.py
+++ b/tests/test_loss.py
@@ -15,15 +15,10 @@
 #
 
 import torch
-import os
-import sys
 import torch.nn as nn
 from copy import deepcopy
 import pytest
 
-module_path = os.path.abspath(os.path.join('..'))
-if module_path not in sys.path:
-    sys.path.append(module_path)
 from distiller import ScheduledTrainingPolicy, CompressionScheduler
 from distiller.policy import PolicyLoss, LossComponent
 
diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py
index bb77de8..5053748 100755
--- a/tests/test_model_summary.py
+++ b/tests/test_model_summary.py
@@ -16,11 +16,6 @@
 
 import logging
 import torch
-import os
-import sys
-module_path = os.path.abspath(os.path.join('..'))
-if module_path not in sys.path:
-    sys.path.append(module_path)
 import distiller
 import pytest
 import common  # common test code
diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py
index aa51f24..2414866 100644
--- a/tests/test_post_train_quant.py
+++ b/tests/test_post_train_quant.py
@@ -13,14 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-import os
 import pytest
 import torch
 import torch.testing
 from collections import OrderedDict
 
-module_path = os.path.abspath(os.path.join('..'))
 from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuantMode, \
     RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseMultWrapper, RangeLinearQuantEltwiseAddWrapper
 import distiller.modules
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 443e452..42a6ab7 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -17,8 +17,6 @@ from collections import namedtuple
 import numpy as np
 import logging
 import torch
-import os
-import sys
 import distiller
 import common
 import pytest
diff --git a/tests/test_quant_utils.py b/tests/test_quant_utils.py
index f613e50..fbc6e1a 100644
--- a/tests/test_quant_utils.py
+++ b/tests/test_quant_utils.py
@@ -16,11 +16,6 @@
 
 import torch
 import pytest
-import sys
-import os
-module_path = os.path.abspath(os.path.join('..'))
-if module_path not in sys.path:
-    sys.path.append(module_path)
 from distiller.quantization import q_utils as qu
 
 
diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py
index f406ae4..fdd47c7 100644
--- a/tests/test_quantizer.py
+++ b/tests/test_quantizer.py
@@ -15,13 +15,11 @@
 #
 
 import torch
-import os
 import torch.nn as nn
 from copy import deepcopy
 from collections import OrderedDict
 import pytest
 
-module_path = os.path.abspath(os.path.join('..'))
 from distiller.quantization import Quantizer
 from distiller.quantization.quantizer import QBits, _ParamToQuant
 from distiller.quantization.quantizer import FP_BKP_PREFIX
diff --git a/tests/test_ranking.py b/tests/test_ranking.py
index d48e927..0efb6c4 100755
--- a/tests/test_ranking.py
+++ b/tests/test_ranking.py
@@ -16,15 +16,7 @@
 
 import logging
 import torch
-import os
-import sys
-try:
-    import distiller
-except ImportError:
-    module_path = os.path.abspath(os.path.join('..'))
-    if module_path not in sys.path:
-        sys.path.append(module_path)
-    import distiller
+import distiller
 import common  # common test code
 
 # Logging configuration
diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py
index 130dcdc..6034c02 100755
--- a/tests/test_thresholding.py
+++ b/tests/test_thresholding.py
@@ -1,10 +1,20 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
 import torch
-import os
-import sys
 import pytest
-module_path = os.path.abspath(os.path.join('..'))
-if module_path not in sys.path:
-    sys.path.append(module_path)
 import distiller
 
 
-- 
GitLab