diff --git a/examples/automated_deep_compression/ADC.py b/examples/automated_deep_compression/ADC.py index a04781150f18b553eeb3187b160511e9f1827931..590acf4c6bd86f1ac12041820616effa2889af33 100755 --- a/examples/automated_deep_compression/ADC.py +++ b/examples/automated_deep_compression/ADC.py @@ -65,7 +65,12 @@ import logging import numpy as np import torch import csv -import gym +try: + import gym +except ImportError as e: + print("WARNING: to use automated compression you will need to install extra packages") + print("See instructions in the header of examples/automated_deep_compression/ADC.py") + raise e from gym import spaces import distiller from apputils import SummaryGraph @@ -183,7 +188,7 @@ def amc_reward_fn(env, top1, top5, vloss, total_macs): experimental_reward_fn = harmonic_mean_reward_fn -def do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn): +def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn): dataset = args.dataset arch = args.arch perform_thinning = True # args.amc_thinning diff --git a/examples/automated_deep_compression/__init__.py b/examples/automated_deep_compression/__init__.py index a05f8a23b16eacba69e0789740f2fbc3195ca653..5b2e97cc905519ff2486a6fca3d93945729d7f38 100755 --- a/examples/automated_deep_compression/__init__.py +++ b/examples/automated_deep_compression/__init__.py @@ -1,2 +1,5 @@ from .automl_args import add_automl_args -from .ADC import do_adc + +def do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn): + from .ADC import do_adc_internal + do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn) diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 6f18978bc9b2a2bb3992837b47c854f2efefb5ca..b3d9513114ee1edb9f70cab4e820730a958ca800 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -657,7 +657,7 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args): save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir) optimizer_data = {'lr': args.lr, 'momentum': args.momentum, 'weight_decay': args.weight_decay} - adc.ADC.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn) + adc.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn) def greedy(model, criterion, optimizer, loggers, args):