From ac9f61c06ac5299a7a00baf582d362570279fe15 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 14 Feb 2019 17:06:22 +0200
Subject: [PATCH] Fix automated-compression imports

To use automated compression you need to install several optional packages
which are not required for other use-cases.
This fix hides the import requirements for users who do not want to install
the extra packages.
---
 examples/automated_deep_compression/ADC.py             | 9 +++++++--
 examples/automated_deep_compression/__init__.py        | 5 ++++-
 examples/classifier_compression/compress_classifier.py | 2 +-
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/examples/automated_deep_compression/ADC.py b/examples/automated_deep_compression/ADC.py
index a047811..590acf4 100755
--- a/examples/automated_deep_compression/ADC.py
+++ b/examples/automated_deep_compression/ADC.py
@@ -65,7 +65,12 @@ import logging
 import numpy as np
 import torch
 import csv
-import gym
+try:
+    import gym
+except ImportError as e:
+    print("WARNING: to use automated compression you will need to install extra packages")
+    print("See instructions in the header of examples/automated_deep_compression/ADC.py")
+    raise e
 from gym import spaces
 import distiller
 from apputils import SummaryGraph
@@ -183,7 +188,7 @@ def amc_reward_fn(env, top1, top5, vloss, total_macs):
 experimental_reward_fn = harmonic_mean_reward_fn
 
 
-def do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn):
+def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn):
     dataset = args.dataset
     arch = args.arch
     perform_thinning = True  # args.amc_thinning
diff --git a/examples/automated_deep_compression/__init__.py b/examples/automated_deep_compression/__init__.py
index a05f8a2..5b2e97c 100755
--- a/examples/automated_deep_compression/__init__.py
+++ b/examples/automated_deep_compression/__init__.py
@@ -1,2 +1,5 @@
 from .automl_args import add_automl_args
-from .ADC import do_adc
+
+def do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn):
+    from .ADC import do_adc_internal
+    do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn)
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 6f18978..b3d9513 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -657,7 +657,7 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args):
 
     save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir)
     optimizer_data = {'lr': args.lr, 'momentum': args.momentum, 'weight_decay': args.weight_decay}
-    adc.ADC.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn)
+    adc.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn)
 
 
 def greedy(model, criterion, optimizer, loggers, args):
-- 
GitLab