diff --git a/examples/word_language_model/generate.py b/examples/word_language_model/generate.py index 9beb06c165ade5d736cf5e21a05615f76f8b7bc7..3849b6987c4a74d9ed59e78b4893ed452334b367 100755 --- a/examples/word_language_model/generate.py +++ b/examples/word_language_model/generate.py @@ -18,12 +18,6 @@ import data # dependency on distiller code. # It's a bit ironic that PyTorch's docs advise against this kind of serialization, # while PyTorch's samples use it: https://pytorch.org/docs/master/notes/serialization.html -import os -import sys -script_dir = os.path.dirname(__file__) -module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Language Model') diff --git a/tests/common.py b/tests/common.py index a2283f1dd726537c0b899e3d518915caa6deb7e4..792e03343c8617f647026420f418f61092e4ddfa 100755 --- a/tests/common.py +++ b/tests/common.py @@ -13,13 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -import os -import sys import torch -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller from distiller.models import create_model diff --git a/tests/test_basic.py b/tests/test_basic.py index 02091d06d875ee76f271c6344957ca6ca76b1927..6c06572bebb90d4f07d20f6d7189ea24350b86c6 100755 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -15,12 +15,7 @@ # import torch -import os -import sys import common -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller import distiller.models as models diff --git a/tests/test_learning_rate.py b/tests/test_learning_rate.py index 42767bb6601e540a037e33680921513fc5c7d187..dd239147be3ba3fa38a93daa4d2842cea4775d1d 100644 --- a/tests/test_learning_rate.py +++ b/tests/test_learning_rate.py @@ -14,13 +14,7 @@ # limitations under the License. # -import os -import sys import pytest -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) - import torch from torch.optim import Optimizer from distiller.learning_rate import MultiStepMultiGammaLR diff --git a/tests/test_loss.py b/tests/test_loss.py index 6c197d09bf1d39ea68066aabf64d73696effe5e0..27452f8459425274d0b0e245c5891a14ad433c09 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -15,15 +15,10 @@ # import torch -import os -import sys import torch.nn as nn from copy import deepcopy import pytest -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) from distiller import ScheduledTrainingPolicy, CompressionScheduler from distiller.policy import PolicyLoss, LossComponent diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py index bb77de8349272b903d7e2963fd2ec15fbf9600f2..505374826287b6a3656f794f936f9f9fcad5ebee 100755 --- a/tests/test_model_summary.py +++ b/tests/test_model_summary.py @@ -16,11 +16,6 @@ import logging import torch -import os -import sys -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller import pytest import common # common test code diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index aa51f2405f5aa8f34389c6531687153c70db4d8f..24148660a9a6fe93432e5fe3f4289e2164f4666a 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -13,14 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -import os import pytest import torch import torch.testing from collections import OrderedDict -module_path = os.path.abspath(os.path.join('..')) from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuantMode, \ RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseMultWrapper, RangeLinearQuantEltwiseAddWrapper import distiller.modules diff --git a/tests/test_pruning.py b/tests/test_pruning.py index 443e45240083ad88108ebd4940764244c6a9c6e9..42a6ab7b39ef7fdc911649a5a05ee5f9128c39a2 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -17,8 +17,6 @@ from collections import namedtuple import numpy as np import logging import torch -import os -import sys import distiller import common import pytest diff --git a/tests/test_quant_utils.py b/tests/test_quant_utils.py index f613e50dcfb4c054f8581f1fe88afa19cf53d911..fbc6e1a31686ca5a73b2c0930beb941e7e0856bf 100644 --- a/tests/test_quant_utils.py +++ b/tests/test_quant_utils.py @@ -16,11 +16,6 @@ import torch import pytest -import sys -import os -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) from distiller.quantization import q_utils as qu diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py index f406ae422ac32551601e3def2f79c03e305be0c2..fdd47c77a87b04548de64cdb892f71b4e4139070 100644 --- a/tests/test_quantizer.py +++ b/tests/test_quantizer.py @@ -15,13 +15,11 @@ # import torch -import os import torch.nn as nn from copy import deepcopy from collections import OrderedDict import pytest -module_path = os.path.abspath(os.path.join('..')) from distiller.quantization import Quantizer from distiller.quantization.quantizer import QBits, _ParamToQuant from distiller.quantization.quantizer import FP_BKP_PREFIX diff --git a/tests/test_ranking.py b/tests/test_ranking.py index d48e927c1055d36735a8a87c06911dddd6922289..0efb6c4353e7e615f112e7ca568acead352b1371 100755 --- a/tests/test_ranking.py +++ b/tests/test_ranking.py @@ -16,15 +16,7 @@ import logging import torch -import os -import sys -try: - import distiller -except ImportError: - module_path = os.path.abspath(os.path.join('..')) - if module_path not in sys.path: - sys.path.append(module_path) - import distiller +import distiller import common # common test code # Logging configuration diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py index 130dcdc42f44b5cd1883315b9f5ddb3abbcd3953..6034c024508030ebe0fb9269f48f26288849724a 100755 --- a/tests/test_thresholding.py +++ b/tests/test_thresholding.py @@ -1,10 +1,20 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch -import os -import sys import pytest -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: - sys.path.append(module_path) import distiller