From 72ef9160b184f4d31d17e5d2fd59f312e1687740 Mon Sep 17 00:00:00 2001 From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com> Date: Mon, 8 Apr 2019 14:27:21 +0300 Subject: [PATCH] Removed sys.path modifications when importing distiller. (#224) --- examples/word_language_model/generate.py | 6 ------ tests/common.py | 6 ------ tests/test_basic.py | 5 ----- tests/test_learning_rate.py | 6 ------ tests/test_loss.py | 5 ----- tests/test_model_summary.py | 5 ----- tests/test_post_train_quant.py | 3 --- tests/test_pruning.py | 2 -- tests/test_quant_utils.py | 5 ----- tests/test_quantizer.py | 2 -- tests/test_ranking.py | 10 +--------- tests/test_thresholding.py | 20 +++++++++++++++----- 12 files changed, 16 insertions(+), 59 deletions(-) diff --git a/examples/word_language_model/generate.py b/examples/word_language_model/generate.py index 9beb06c..3849b69 100755 --- a/examples/word_language_model/generate.py +++ b/examples/word_language_model/generate.py @@ -18,12 +18,6 @@ import data # dependency on distiller code. # It's a bit ironic that PyTorch's docs advise against this kind of serialization, # while PyTorch's samples use it: https://pytorch.org/docs/master/notes/serialization.html -import os -import sys -script_dir = os.path.dirname(__file__) -module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Language Model') diff --git a/tests/common.py b/tests/common.py index a2283f1..792e033 100755 --- a/tests/common.py +++ b/tests/common.py @@ -13,13 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -import os -import sys import torch -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller from distiller.models import create_model diff --git a/tests/test_basic.py b/tests/test_basic.py index 02091d0..6c06572 100755 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -15,12 +15,7 @@ # import torch -import os -import sys import common -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller import distiller.models as models diff --git a/tests/test_learning_rate.py b/tests/test_learning_rate.py index 42767bb..dd23914 100644 --- a/tests/test_learning_rate.py +++ b/tests/test_learning_rate.py @@ -14,13 +14,7 @@ # limitations under the License. # -import os -import sys import pytest -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) - import torch from torch.optim import Optimizer from distiller.learning_rate import MultiStepMultiGammaLR diff --git a/tests/test_loss.py b/tests/test_loss.py index 6c197d0..27452f8 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -15,15 +15,10 @@ # import torch -import os -import sys import torch.nn as nn from copy import deepcopy import pytest -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) from distiller import ScheduledTrainingPolicy, CompressionScheduler from distiller.policy import PolicyLoss, LossComponent diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py index bb77de8..5053748 100755 --- a/tests/test_model_summary.py +++ b/tests/test_model_summary.py @@ -16,11 +16,6 @@ import logging import torch -import os -import sys -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller import pytest import common # common test code diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index aa51f24..2414866 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -13,14 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -import os import pytest import torch import torch.testing from collections import OrderedDict -module_path = os.path.abspath(os.path.join('..')) from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuantMode, \ RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseMultWrapper, RangeLinearQuantEltwiseAddWrapper import distiller.modules diff --git a/tests/test_pruning.py b/tests/test_pruning.py index 443e452..42a6ab7 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -17,8 +17,6 @@ from collections import namedtuple import numpy as np import logging import torch -import os -import sys import distiller import common import pytest diff --git a/tests/test_quant_utils.py b/tests/test_quant_utils.py index f613e50..fbc6e1a 100644 --- a/tests/test_quant_utils.py +++ b/tests/test_quant_utils.py @@ -16,11 +16,6 @@ import torch import pytest -import sys -import os -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) from distiller.quantization import q_utils as qu diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py index f406ae4..fdd47c7 100644 --- a/tests/test_quantizer.py +++ b/tests/test_quantizer.py @@ -15,13 +15,11 @@ # import torch -import os import torch.nn as nn from copy import deepcopy from collections import OrderedDict import pytest -module_path = os.path.abspath(os.path.join('..')) from distiller.quantization import Quantizer from distiller.quantization.quantizer import QBits, _ParamToQuant from distiller.quantization.quantizer import FP_BKP_PREFIX diff --git a/tests/test_ranking.py b/tests/test_ranking.py index d48e927..0efb6c4 100755 --- a/tests/test_ranking.py +++ b/tests/test_ranking.py @@ -16,15 +16,7 @@ import logging import torch -import os -import sys -try: - import distiller -except ImportError: - module_path = os.path.abspath(os.path.join('..')) - if module_path not in sys.path: - sys.path.append(module_path) - import distiller +import distiller import common # common test code # Logging configuration diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py index 130dcdc..6034c02 100755 --- a/tests/test_thresholding.py +++ b/tests/test_thresholding.py @@ -1,10 +1,20 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch -import os -import sys import pytest -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller -- GitLab