diff --git a/examples/auto_compression/amc/README.md b/examples/auto_compression/amc/README.md
index acf8ed2044cf0e5ae1a77dfacca3ed3ae8823ee8..94be27d325aa6e6c520803bae5ef5452c7a25189 100755
--- a/examples/auto_compression/amc/README.md
+++ b/examples/auto_compression/amc/README.md
@@ -32,6 +32,35 @@ AMC [1] trains a Deep Deterministic Policy Gradient (DDPG) RL agent to compress
 
 We thank Prof. Song Han and his team for their [help](https://github.com/mit-han-lab/amc-compressed-models) with certain critical parts of this implementation.  However, all bugs in interpretation and/or implementation are ours ;-).
 
+## Installation
+Our AMC implementation is designed so that you can easily switch between RL libraries (i.e., different agent implementations). The `--amc-rllib` argument selects which library to use.
+For `--amc-rllib=coach` you need to install Coach in your Python virtual-env. The "hanlab" option, `--amc-rllib=hanlab`, refers to MIT HAN Lab's DDPG agent implementation, which is bundled with Distiller.
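+For example, the sample command line in `amc.py`'s docstring selects HAN Lab's agent:<br>
+`$ python3 amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../../ssl/checkpoints/checkpoint_trained_dense.pth.tar --amc --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --amc-rllib=hanlab -j=1`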
+
+### Using Intel AI Lab's Coach
+Some of the features required for AMC are not yet in the official [Coach](https://github.com/NervanaSystems/coach) release, so you need the code on Coach's `master` branch.
+Follow Coach's [installation instructions](https://github.com/NervanaSystems/coach#installation) for a development environment, making sure to check out the `master` branch.
+<br>
+Coach uses TensorFlow and Distiller uses PyTorch, and the two frameworks do not share GPUs well.  The easiest work-around for this is to execute Coach code (the RL agent) on the CPU and Distiller code on the GPU(s).
+<br>
+To do this, uninstall TensorFlow (both the CPU and GPU packages):<br>
+`$ pip uninstall tensorflow tensorflow-gpu`
+<br>
+and then install the CPU-only TensorFlow package: <br>
+`$ pip install tensorflow`
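+
+As a quick sanity check (not part of the original setup instructions, just an illustrative snippet), you can verify that PyTorch still sees the GPU(s) while TensorFlow is now CPU-only:
+```python
+# Sanity check: Distiller (PyTorch) should run on the GPU(s), Coach (TensorFlow) on the CPU.
+import torch
+import tensorflow as tf
+
+print("PyTorch sees CUDA:", torch.cuda.is_available())       # expect True on a GPU machine
+print("TensorFlow sees a GPU:", tf.test.is_gpu_available())  # expect False with the CPU-only package
+```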
+ 
+### Using MIT HAN Lab's DDPG agent
+We integrated MIT HAN Lab's AMC DDPG agent directly into the code base, so no additional installation is required to use it.
+
+## Notable algorithm details
+
+### Feature-map reconstruction
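+In broad strokes (a sketch of the idea, not a description of Distiller's exact implementation): when pruning with `--amc-prune-method=fm-reconstruction`, the weights acting on the surviving channels are re-fit with least squares so that the layer's output feature-maps approximate those of the original layer. The helper below is illustrative only; `least_square_sklearn()` in `rl_libs/hanlab/utils.py` (copied from HAN Lab's AMC code) wraps the same sklearn call.
+```python
+# Illustrative least-squares feature-map reconstruction (function and argument names are hypothetical).
+from sklearn.linear_model import LinearRegression
+
+def reconstruct_weights(X_kept, Y_orig):
+    """Fit W so that X_kept @ W.T approximates the original outputs Y_orig.
+    X_kept: sampled layer inputs restricted to the channels that were kept.
+    Y_orig: the corresponding outputs of the original (unpruned) layer."""
+    reg = LinearRegression(fit_intercept=False)
+    reg.fit(X_kept, Y_orig)
+    return reg.coef_
+```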
+
+### DDPG Agent
+- AMC uses a truncated normal distribution for its exploration policy, with a noise scale that decays per episode (see the sketch below).
+- AMC uses reward shaping (see [1]).
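+
+A minimal sketch of the exploration step (mirroring `DDPG.select_action()` in `rl_libs/hanlab/agent.py`; the default noise parameters below are illustrative, not the agent's actual settings):
+```python
+# Truncated-normal exploration with per-episode noise decay (sketch of DDPG.select_action).
+import numpy as np
+from scipy import stats
+
+def explore(mu, episode, warmup, init_delta=0.5, delta_decay=0.95, lower=0.0, upper=1.0):
+    """Perturb the actor's proposed action `mu` (a pruning ratio in [lower, upper])."""
+    sigma = init_delta * (delta_decay ** (episode - warmup))  # noise shrinks as training progresses
+    action = stats.truncnorm.rvs((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
+    return np.clip(action, lower, upper)
+```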
+
+## References
+
 [1] AMC: AutoML for Model Compression and Acceleration on Mobile Devices.<br>
      Yihui He, Ji Lin, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han.<br>
      In Proceedings of the European Conference on Computer Vision (ECCV), 2018.<br>
diff --git a/examples/auto_compression/amc/amc.py b/examples/auto_compression/amc/amc.py
index 46d633c265e109107ff5e6e30e277727ad5d7ece..d30473e27b082e6caa26ff68d4f6d87129b72a8b 100755
--- a/examples/auto_compression/amc/amc.py
+++ b/examples/auto_compression/amc/amc.py
@@ -15,7 +15,7 @@
 #
 
 """
-$ python3 amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../../ssl/checkpoints/checkpoint_trained_dense.pth.tar --amc --amc-procol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --amc-rllib=private -j=1
+$ python3 amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../../ssl/checkpoints/checkpoint_trained_dense.pth.tar --amc --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --amc-rllib=hanlab -j=1
 """
 
 
@@ -143,9 +143,9 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo
         env2 = create_environment()
         steps_per_episode = env1.steps_per_episode
         rl.solve(env1, env2)
-    elif args.amc_rllib == "private":
-        from rl_libs.private import private_if
-        rl = private_if.RlLibInterface()
+    elif args.amc_rllib == "hanlab":
+        from rl_libs.hanlab import hanlab_if
+        rl = hanlab_if.RlLibInterface()
         args.observation_len = len(Observation._fields)
         rl.solve(env1, args)
     elif args.amc_rllib == "coach":
diff --git a/examples/auto_compression/amc/amc_args.py b/examples/auto_compression/amc/amc_args.py
index 38ecb0af5e79e4279ad8aaffcae89f4771cba1af..2238691a02976a5bc6f0be8b8cded28ee151e073 100755
--- a/examples/auto_compression/amc/amc_args.py
+++ b/examples/auto_compression/amc/amc_args.py
@@ -46,7 +46,7 @@ def add_automl_args(argparser):
                        default="l1-rank", help="The pruning method")
     group.add_argument('--amc-rllib', choices=["coach",
                                                "spinningup",
-                                               "private",
+                                               "hanlab",
                                                "random"],
                        default=None, help="Choose which RL library to use")
     group.add_argument('--amc-group-size', type=int, default=1,
diff --git a/examples/auto_compression/amc/auto_compression_channels.yaml b/examples/auto_compression/amc/auto_compression_channels.yaml
index 558c9754f1036454074ade0057c93c723764d963..737c3e492ee488af4221d4591c94cd0e577cbb1e 100755
--- a/examples/auto_compression/amc/auto_compression_channels.yaml
+++ b/examples/auto_compression/amc/auto_compression_channels.yaml
@@ -1,6 +1,6 @@
 # rl_lib:
 #     # Choose which RL library to use: Coach from Intel AI Lab, or Spinup from OpenAI
-#     name: private # one of: coach, spinningup, private
+#     name: hanlab # one of: coach, spinningup, hanlab
 
 network:
   mobilenet:
diff --git a/examples/auto_compression/amc/auto_compression_filters.yaml b/examples/auto_compression/amc/auto_compression_filters.yaml
index c80d59c43074104db37d9d43985c883c3d685040..16293488f626e447e51763fe94c33896ffe0b2ce 100755
--- a/examples/auto_compression/amc/auto_compression_filters.yaml
+++ b/examples/auto_compression/amc/auto_compression_filters.yaml
@@ -1,6 +1,6 @@
 # rl_lib:
 #     # Choose which RL library to use: Coach from Intel AI Lab, or Spinup from OpenAI
-#     name: private # one of: coach, spinningup, private
+#     name: hanlab # one of: coach, spinningup, hanlab
 
 network:
   mobilenet:
diff --git a/examples/auto_compression/amc/jupyter/amc_plain20.ipynb b/examples/auto_compression/amc/jupyter/amc_plain20.ipynb
index 3ccfb8c76fc5b5b33d757af1eee7341f7b79b5df..e58ec546db892c9e263b6f4c25b4d58d97c45192 100644
--- a/examples/auto_compression/amc/jupyter/amc_plain20.ipynb
+++ b/examples/auto_compression/amc/jupyter/amc_plain20.ipynb
@@ -79,7 +79,7 @@
     "\n",
     "The command-line is provided below:\n",
     "    \n",
-    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py  --arch=plain20_cifar ${CIFAR10_PATH} --resume=checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=0.5 --etrs=0.5 --amc-rllib=private -j=1\n",
+    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py  --arch=plain20_cifar ${CIFAR10_PATH} --resume=checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=0.5 --etrs=0.5 --amc-rllib=hanlab -j=1\n",
     "    \n",
     "Each colored line represents one experiment execution instance. We plot the search-Top1 score of discovered networks as the RL-based AMC system learns to find better compressed networks. You might be impressed by:\n",
     "* The variability in behavior, which is typical for RL algorithms.\n",
@@ -358,7 +358,7 @@
    "source": [
     "## Using a different reward function\n",
     "\n",
-    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private-punish compress_classifier.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=punish-agent --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=../automated_deep_compression/auto_compression_channels.yaml --evs=0.5 --etrs=0.5 --amc-rllib=private -j=1"
+    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private-punish compress_classifier.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=punish-agent --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=../automated_deep_compression/auto_compression_channels.yaml --evs=0.5 --etrs=0.5 --amc-rllib=hanlab -j=1"
    ]
   },
   {
@@ -448,7 +448,7 @@
     "        validation=5000\n",
     "        test=750\n",
     "\n",
-    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py  --arch=plain20_cifar ${CIFAR10_PATH} --resume=${CHECKPOINTS_PATH}/checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=1.0 --etrs=0.01 --amc-rllib=private -j=1\n",
+    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py  --arch=plain20_cifar ${CIFAR10_PATH} --resume=${CHECKPOINTS_PATH}/checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=1.0 --etrs=0.01 --amc-rllib=hanlab -j=1\n",
     "    \n",
     "    time python parallel-finetune.py --scan-dir=${AMC_EXP_PATH}/plain20-ddpg-private/2019.08.01-181040 --arch=plain20_cifar --lr=0.1 --vs=0 -p=50 --compress=plain20_fine_tune.yaml ${CIFAR10_PATH} -j=1 --epochs=60 --output-csv=ft_60epoch_results.csv --processes=16"
    ]
@@ -491,7 +491,7 @@
    "source": [
     "Now repeating the original experiment, with checkpoint savings so we can fine-tune the best ones\n",
     "\n",
-    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py  --arch=plain20_cifar ${CIFAR10_PATH} --resume=${CHECKPOINTS_PATH}/checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=0.5 --etrs=0.01 --amc-rllib=private -j=1 --amc-save-chkpts\n",
+    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py  --arch=plain20_cifar ${CIFAR10_PATH} --resume=${CHECKPOINTS_PATH}/checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=0.5 --etrs=0.01 --amc-rllib=hanlab -j=1 --amc-save-chkpts\n",
     "    \n",
     "    time python ../../classifier_compression/parallel-finetune.py --scan-dir=${AMC_EXP_PATH}/plain20-ddpg-private/2019.08.03-000628 --arch=plain20_cifar --lr=0.1 --vs=0 -p=50 --compress=../plain20_fine_tune.yaml ${CIFAR10_PATH} -j=1 --epochs=60 --output-csv=ft_60epoch_results.csv --processes=16 --top-performing-chkpts"
    ]
diff --git a/examples/auto_compression/amc/jupyter/amc_resnet20.ipynb b/examples/auto_compression/amc/jupyter/amc_resnet20.ipynb
index 4124e146c7ed022fa8df591b5e5854953032e139..4232c7baed315ff371399cce74e2fb7407ecf058 100644
--- a/examples/auto_compression/amc/jupyter/amc_resnet20.ipynb
+++ b/examples/auto_compression/amc/jupyter/amc_resnet20.ipynb
@@ -79,7 +79,7 @@
     "\n",
     "The command-line is provided below:\n",
     "    \n",
-    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/resnet20-ddpg-private amc.py  --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --amc-rllib=private -j=1\n",
+    "    time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/resnet20-ddpg-private amc.py  --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --amc-rllib=hanlab -j=1\n",
     "    \n",
     "Each colored line represents one experiment execution instance. We plot the search-Top1 score of discovered networks as the RL-based AMC system learns to find better compressed networks. You might be impressed by:\n",
     "* The variability in behavior, which is typical for RL algorithms.\n",
@@ -233,13 +233,13 @@
   "pycharm": {
    "stem_cell": {
     "cell_type": "raw",
+    "source": [],
     "metadata": {
      "collapsed": false
-    },
-    "source": []
+    }
    }
   }
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/examples/auto_compression/amc/rl_libs/hanlab/README.md b/examples/auto_compression/amc/rl_libs/hanlab/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9aebcd9ea12a4be30774068d0a10f61d00e00fd9
--- /dev/null
+++ b/examples/auto_compression/amc/rl_libs/hanlab/README.md
@@ -0,0 +1,15 @@
+The code in this directory originates from HAN Lab's [AMC github repository](https://github.com/mit-han-lab/amc-release).
+
+We copied the DDPG files from [HAN Lab's github](https://github.com/mit-han-lab/amc-release) to `distiller/examples/auto_compression/amc/rl_libs/hanlab/`.<br>
+Specifically:
+- `mit-han-lab/amc-release/tree/master/lib/agent.py` ==> `distiller/examples/auto_compression/amc/rl_libs/hanlab`
+- `mit-han-lab/amc-release/tree/master/lib/memory.py` ==> `distiller/examples/auto_compression/amc/rl_libs/hanlab`
+- `mit-han-lab/amc-release/tree/master/lib/utils.py` ==> `distiller/examples/auto_compression/amc/rl_libs/hanlab`
+
+The function `train()` was copied from `mit-han-lab/amc-release/tree/master/lib/amc_search.py` into the new file `distiller/examples/auto_compression/amc/rl_libs/hanlab/agent.py`.
+
+Some non-functional changes were introduced so that the code runs under Distiller.
+
+The MIT license was copied to `distiller/licenses/hanlab-amc-license.txt`.
+
+ 
\ No newline at end of file
diff --git a/examples/auto_compression/amc/rl_libs/private/__init__.py b/examples/auto_compression/amc/rl_libs/hanlab/__init__.py
similarity index 100%
rename from examples/auto_compression/amc/rl_libs/private/__init__.py
rename to examples/auto_compression/amc/rl_libs/hanlab/__init__.py
diff --git a/examples/auto_compression/amc/rl_libs/hanlab/agent.py b/examples/auto_compression/amc/rl_libs/hanlab/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac2045c64bb535e7168fd75a9373672c8b128b1
--- /dev/null
+++ b/examples/auto_compression/amc/rl_libs/hanlab/agent.py
@@ -0,0 +1,298 @@
+# Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices"
+# Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han
+# {jilin, songhan}@mit.edu
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.optim import Adam
+import os
+from .memory import SequentialMemory
+from .utils import to_numpy, to_tensor
+from copy import deepcopy
+
+criterion = nn.MSELoss()
+USE_CUDA = torch.cuda.is_available()
+
+
+class Actor(nn.Module):
+    def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300):
+        super(Actor, self).__init__()
+        self.fc1 = nn.Linear(nb_states, hidden1)
+        self.fc2 = nn.Linear(hidden1, hidden2)
+        self.fc3 = nn.Linear(hidden2, nb_actions)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu(out)
+        out = self.fc2(out)
+        out = self.relu(out)
+        out = self.fc3(out)
+        out = self.sigmoid(out)
+        return out
+
+
+class Critic(nn.Module):
+    def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300):
+        super(Critic, self).__init__()
+        self.fc11 = nn.Linear(nb_states, hidden1)
+        self.fc12 = nn.Linear(nb_actions, hidden1)
+        self.fc2 = nn.Linear(hidden1, hidden2)
+        self.fc3 = nn.Linear(hidden2, 1)
+        self.relu = nn.ReLU()
+
+    def forward(self, xs):
+        x, a = xs
+        out = self.fc11(x) + self.fc12(a)
+        out = self.relu(out)
+        out = self.fc2(out)
+        out = self.relu(out)
+        out = self.fc3(out)
+        return out
+
+
+class DDPG(object):
+    def __init__(self, nb_states, nb_actions, args):
+
+        self.nb_states = nb_states
+        self.nb_actions = nb_actions
+
+        # Create Actor and Critic Network
+        net_cfg = {
+            'hidden1': args.hidden1,
+            'hidden2': args.hidden2,
+        }
+        self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg)
+        self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg)
+        self.actor_optim = Adam(self.actor.parameters(), lr=args.lr_a)
+
+        self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg)
+        self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg)
+        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr_c)
+
+        self.hard_update(self.actor_target, self.actor)  # Make sure target is with the same weight
+        self.hard_update(self.critic_target, self.critic)
+
+        # Create replay buffer
+        self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length)
+
+        # Hyper-parameters
+        self.batch_size = args.bsize
+        self.tau = args.tau
+        self.discount = args.discount
+        self.depsilon = 1.0 / args.epsilon
+        self.lbound = 0.  # args.lbound
+        self.rbound = 1.  # args.rbound
+
+        # noise
+        self.init_delta = args.init_delta
+        self.delta_decay = args.delta_decay
+        self.warmup = args.warmup
+
+        #
+        self.epsilon = 1.0
+        self.is_training = True
+
+        #
+        if USE_CUDA: self.cuda()
+
+        # moving average baseline
+        self.moving_average = None
+        self.moving_alpha = 0.5  # based on batch, so small
+
+    def update_policy(self):
+        # Sample batch
+        state_batch, action_batch, reward_batch, \
+        next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size)
+
+        # normalize the reward
+        batch_mean_reward = np.mean(reward_batch)
+        if self.moving_average is None:
+            self.moving_average = batch_mean_reward
+        else:
+            self.moving_average += self.moving_alpha * (batch_mean_reward - self.moving_average)
+        reward_batch -= self.moving_average
+
+        # Prepare for the target q batch
+        with torch.no_grad():
+            next_q_values = self.critic_target([
+                to_tensor(next_state_batch),
+                self.actor_target(to_tensor(next_state_batch)),
+            ])
+
+        target_q_batch = to_tensor(reward_batch) + \
+                         self.discount * to_tensor(terminal_batch.astype(float)) * next_q_values
+
+        # Critic update
+        self.critic.zero_grad()
+
+        q_batch = self.critic([to_tensor(state_batch), to_tensor(action_batch)])
+
+        value_loss = criterion(q_batch, target_q_batch)
+        value_loss.backward()
+        self.critic_optim.step()
+
+        # Actor update
+        self.actor.zero_grad()
+
+        policy_loss = -self.critic([
+            to_tensor(state_batch),
+            self.actor(to_tensor(state_batch))
+        ])
+
+        policy_loss = policy_loss.mean()
+        policy_loss.backward()
+        self.actor_optim.step()
+
+        # Target update
+        self.soft_update(self.actor_target, self.actor)
+        self.soft_update(self.critic_target, self.critic)
+
+    def eval(self):
+        self.actor.eval()
+        self.actor_target.eval()
+        self.critic.eval()
+        self.critic_target.eval()
+
+    def cuda(self):
+        self.actor.cuda()
+        self.actor_target.cuda()
+        self.critic.cuda()
+        self.critic_target.cuda()
+
+    def observe(self, r_t, s_t, s_t1, a_t, done):
+        if self.is_training:
+            self.memory.append(s_t, a_t, r_t, done)  # save to memory
+            # self.s_t = s_t1
+
+    def random_action(self):
+        action = np.random.uniform(self.lbound, self.rbound, self.nb_actions)
+        return action
+
+    def select_action(self, s_t, episode):
+        # assert episode >= self.warmup, 'Episode: {} warmup: {}'.format(episode, self.warmup)
+        action = to_numpy(self.actor(to_tensor(np.array(s_t).reshape(1, -1)))).squeeze(0)
+        delta = self.init_delta * (self.delta_decay ** (episode - self.warmup))
+        # action += self.is_training * max(self.epsilon, 0) * self.random_process.sample()
+        action = self.sample_from_truncated_normal_distribution(lower=self.lbound, upper=self.rbound, mu=action, sigma=delta)
+        action = np.clip(action, self.lbound, self.rbound)
+
+        # self.a_t = action
+        return action
+
+    def reset(self, obs):
+        pass
+        # self.s_t = obs
+        # self.random_process.reset_states()
+
+    def load_weights(self, output):
+        if output is None: return
+
+        self.actor.load_state_dict(
+            torch.load('{}/actor.pkl'.format(output))
+        )
+
+        self.critic.load_state_dict(
+            torch.load('{}/critic.pkl'.format(output))
+        )
+
+    def save_model(self, output):
+        torch.save(
+            self.actor.state_dict(),
+            '{}/actor.pkl'.format(output)
+        )
+        torch.save(
+            self.critic.state_dict(),
+            '{}/critic.pkl'.format(output)
+        )
+
+    def soft_update(self, target, source):
+        for target_param, param in zip(target.parameters(), source.parameters()):
+            target_param.data.copy_(
+                target_param.data * (1.0 - self.tau) + param.data * self.tau
+            )
+
+    def hard_update(self, target, source):
+        for target_param, param in zip(target.parameters(), source.parameters()):
+            target_param.data.copy_(param.data)
+
+    def sample_from_truncated_normal_distribution(self, lower, upper, mu, sigma, size=1):
+        from scipy import stats
+        return stats.truncnorm.rvs((lower-mu)/sigma, (upper-mu)/sigma, loc=mu, scale=sigma, size=size)
+
+
+def train(num_episode, agent, env, output, warmup):
+    text_writer = open(os.path.join(output, 'sh_log.txt'), 'w')
+    print('=> Output path: {}...'.format(output))
+
+    agent.is_training = True
+    step = episode = episode_steps = 0
+    episode_reward = 0.
+    observation = None
+    T = []  # trajectory
+    while episode < num_episode:  # counting based on episode
+        # reset if it is the start of episode
+        if observation is None:
+            observation = deepcopy(env.reset())
+            agent.reset(observation)
+
+        # agent pick action ...
+        if episode <= warmup:
+            action = agent.random_action()
+        else:
+            action = agent.select_action(observation, episode=episode)
+
+        # env response with next_observation, reward, terminate_info
+        observation2, reward, done, info = env.step(action)
+        env.render()
+        observation2 = deepcopy(observation2)
+
+        T.append([reward, deepcopy(observation), deepcopy(observation2), action, done])
+
+        # [optional] save intermediate model
+        # if episode % int(num_episode / 3) == 0:
+        #     agent.save_model(output)
+
+        # update
+        step += 1
+        episode_steps += 1
+        episode_reward += reward
+        observation = deepcopy(observation2)
+
+        if done:  # end of episode
+            print('#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}'.format(episode, episode_reward,
+                                                                                 info['accuracy'],
+                                                                                 info['compress_ratio']))
+            text_writer.write(
+                '#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}\n'.format(episode, episode_reward,
+                                                                                 info['accuracy'],
+                                                                                 info['compress_ratio']))
+            final_reward = T[-1][0]
+            # print('final_reward: {}'.format(final_reward))
+            # agent observe and update policy
+            for r_t, s_t, s_t1, a_t, done in T:
+                agent.observe(final_reward, s_t, s_t1, a_t, done)
+                if episode > warmup:
+                    agent.update_policy()
+
+            # reset
+            observation = None
+            episode_steps = 0
+            episode_reward = 0.
+            episode += 1
+            T = []
+
+            # tfwriter.add_scalar('reward/last', final_reward, episode)
+            # tfwriter.add_scalar('reward/best', env.best_reward, episode)
+            # tfwriter.add_scalar('info/accuracy', info['accuracy'], episode)
+            # tfwriter.add_scalar('info/compress_ratio', info['compress_ratio'], episode)
+            # tfwriter.add_text('info/best_policy', str(env.best_strategy), episode)
+            # # record the preserve rate for each layer
+            # for i, preserve_rate in enumerate(env.strategy):
+            #     tfwriter.add_scalar('preserve_rate/{}'.format(i), preserve_rate, episode)
+
+            text_writer.write('best reward: {}\n'.format(env.best_reward))
+            #text_writer.write('best policy: {}\n'.format(env.best_strategy))
+    text_writer.close()
diff --git a/examples/auto_compression/amc/rl_libs/private/private_if.py b/examples/auto_compression/amc/rl_libs/hanlab/hanlab_if.py
similarity index 90%
rename from examples/auto_compression/amc/rl_libs/private/private_if.py
rename to examples/auto_compression/amc/rl_libs/hanlab/hanlab_if.py
index 2ca43cebad36d05bf02b08261d7eca7be9cd4849..0cbb4df6a58ca534eacad9033e01b525ca13dda8 100755
--- a/examples/auto_compression/amc/rl_libs/private/private_if.py
+++ b/examples/auto_compression/amc/rl_libs/hanlab/hanlab_if.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 
-from examples.auto_compression.amc.rl_libs.private.agent import DDPG, train
+from examples.auto_compression.amc.rl_libs.hanlab.agent import DDPG, train
 import logging
 
 
@@ -27,10 +27,10 @@ class ArgsContainer(object):
 
 
 class RlLibInterface(object):
-    """Interface to a private DDPG impelementation."""
+    """Interface to HAN Lab's DDPG implementation."""
 
     def solve(self, env, args):
-        msglogger.info("AMC: Using private")
+        msglogger.info("AMC: Using hanlab")
         
         agent_args = ArgsContainer()
         agent_args.bsize = args.batch_size
diff --git a/examples/auto_compression/amc/rl_libs/hanlab/memory.py b/examples/auto_compression/amc/rl_libs/hanlab/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..493740c6351bc19b7648e7ba5fce9735bb783de6
--- /dev/null
+++ b/examples/auto_compression/amc/rl_libs/hanlab/memory.py
@@ -0,0 +1,228 @@
+# Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices"
+# Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han
+# {jilin, songhan}@mit.edu
+
+from __future__ import absolute_import
+from collections import deque, namedtuple
+import warnings
+import random
+
+import numpy as np
+
+# [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/memory.py
+
+# This is to be understood as a transition: Given `state0`, performing `action`
+# yields `reward` and results in `state1`, which might be `terminal`.
+Experience = namedtuple('Experience', 'state0, action, reward, state1, terminal1')
+
+
+def sample_batch_indexes(low, high, size):
+    if high - low >= size:
+        # We have enough data. Draw without replacement, that is each index is unique in the
+        # batch. We cannot use `np.random.choice` here because it is horribly inefficient as
+        # the memory grows. See https://github.com/numpy/numpy/issues/2764 for a discussion.
+        # `random.sample` does the same thing (drawing without replacement) and is way faster.
+        r = range(low, high)
+        batch_idxs = random.sample(r, size)
+    else:
+        # Not enough data. Help ourselves with sampling from the range, but the same index
+        # can occur multiple times. This is not good and should be avoided by picking a
+        # large enough warm-up phase.
+        warnings.warn(
+            'Not enough entries to sample without replacement. '
+            'Consider increasing your warm-up phase to avoid oversampling!')
+        batch_idxs = np.random.randint(low, high, size=size)
+    assert len(batch_idxs) == size
+    return batch_idxs
+
+
+class RingBuffer(object):
+    def __init__(self, maxlen):
+        self.maxlen = maxlen
+        self.start = 0
+        self.length = 0
+        self.data = [None for _ in range(maxlen)]
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        if idx < 0 or idx >= self.length:
+            raise KeyError()
+        return self.data[(self.start + idx) % self.maxlen]
+
+    def append(self, v):
+        if self.length < self.maxlen:
+            # We have space, simply increase the length.
+            self.length += 1
+        elif self.length == self.maxlen:
+            # No space, "remove" the first item.
+            self.start = (self.start + 1) % self.maxlen
+        else:
+            # This should never happen.
+            raise RuntimeError()
+        self.data[(self.start + self.length - 1) % self.maxlen] = v
+
+
+def zeroed_observation(observation):
+    if hasattr(observation, 'shape'):
+        return np.zeros(observation.shape)
+    elif hasattr(observation, '__iter__'):
+        out = []
+        for x in observation:
+            out.append(zeroed_observation(x))
+        return out
+    else:
+        return 0.
+
+
+class Memory(object):
+    def __init__(self, window_length, ignore_episode_boundaries=False):
+        self.window_length = window_length
+        self.ignore_episode_boundaries = ignore_episode_boundaries
+
+        self.recent_observations = deque(maxlen=window_length)
+        self.recent_terminals = deque(maxlen=window_length)
+
+    def sample(self, batch_size, batch_idxs=None):
+        raise NotImplementedError()
+
+    def append(self, observation, action, reward, terminal, training=True):
+        self.recent_observations.append(observation)
+        self.recent_terminals.append(terminal)
+
+    def get_recent_state(self, current_observation):
+        # This code is slightly complicated by the fact that subsequent observations might be
+        # from different episodes. We ensure that an experience never spans multiple episodes.
+        # This is probably not that important in practice but it seems cleaner.
+        state = [current_observation]
+        idx = len(self.recent_observations) - 1
+        for offset in range(0, self.window_length - 1):
+            current_idx = idx - offset
+            current_terminal = self.recent_terminals[current_idx - 1] if current_idx - 1 >= 0 else False
+            if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal):
+                # The previously handled observation was terminal, don't add the current one.
+                # Otherwise we would leak into a different episode.
+                break
+            state.insert(0, self.recent_observations[current_idx])
+        while len(state) < self.window_length:
+            state.insert(0, zeroed_observation(state[0]))
+        return state
+
+    def get_config(self):
+        config = {
+            'window_length': self.window_length,
+            'ignore_episode_boundaries': self.ignore_episode_boundaries,
+        }
+        return config
+
+
+class SequentialMemory(Memory):
+    def __init__(self, limit, **kwargs):
+        super(SequentialMemory, self).__init__(**kwargs)
+
+        self.limit = limit
+
+        # Do not use deque to implement the memory. This data structure may seem convenient but
+        # it is way too slow on random access. Instead, we use our own ring buffer implementation.
+        self.actions = RingBuffer(limit)
+        self.rewards = RingBuffer(limit)
+        self.terminals = RingBuffer(limit)
+        self.observations = RingBuffer(limit)
+
+    def sample(self, batch_size, batch_idxs=None):
+        if batch_idxs is None:
+            # Draw random indexes such that we have at least a single entry before each
+            # index.
+            batch_idxs = sample_batch_indexes(0, self.nb_entries - 1, size=batch_size)
+        batch_idxs = np.array(batch_idxs) + 1
+        assert np.min(batch_idxs) >= 1
+        assert np.max(batch_idxs) < self.nb_entries
+        assert len(batch_idxs) == batch_size
+
+        # Create experiences
+        experiences = []
+        for idx in batch_idxs:
+            terminal0 = self.terminals[idx - 2] if idx >= 2 else False
+            while terminal0:
+                # Skip this transition because the environment was reset here. Select a new, random
+                # transition and use this instead. This may cause the batch to contain the same
+                # transition twice.
+                idx = sample_batch_indexes(1, self.nb_entries, size=1)[0]
+                terminal0 = self.terminals[idx - 2] if idx >= 2 else False
+            assert 1 <= idx < self.nb_entries
+
+            # This code is slightly complicated by the fact that subsequent observations might be
+            # from different episodes. We ensure that an experience never spans multiple episodes.
+            # This is probably not that important in practice but it seems cleaner.
+            state0 = [self.observations[idx - 1]]
+            for offset in range(0, self.window_length - 1):
+                current_idx = idx - 2 - offset
+                current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False
+                if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal):
+                    # The previously handled observation was terminal, don't add the current one.
+                    # Otherwise we would leak into a different episode.
+                    break
+                state0.insert(0, self.observations[current_idx])
+            while len(state0) < self.window_length:
+                state0.insert(0, zeroed_observation(state0[0]))
+            action = self.actions[idx - 1]
+            reward = self.rewards[idx - 1]
+            terminal1 = self.terminals[idx - 1]
+
+            # Okay, now we need to create the follow-up state. This is state0 shifted on timestep
+            # to the right. Again, we need to be careful to not include an observation from the next
+            # episode if the last state is terminal.
+            state1 = [np.copy(x) for x in state0[1:]]
+            state1.append(self.observations[idx])
+
+            assert len(state0) == self.window_length
+            assert len(state1) == len(state0)
+            experiences.append(Experience(state0=state0, action=action, reward=reward,
+                                          state1=state1, terminal1=terminal1))
+        assert len(experiences) == batch_size
+        return experiences
+
+    def sample_and_split(self, batch_size, batch_idxs=None):
+        experiences = self.sample(batch_size, batch_idxs)
+
+        state0_batch = []
+        reward_batch = []
+        action_batch = []
+        terminal1_batch = []
+        state1_batch = []
+        for e in experiences:
+            state0_batch.append(e.state0)
+            state1_batch.append(e.state1)
+            reward_batch.append(e.reward)
+            action_batch.append(e.action)
+            terminal1_batch.append(0. if e.terminal1 else 1.)
+
+        # Prepare and validate parameters.
+        state0_batch = np.array(state0_batch, 'double').reshape(batch_size, -1)
+        state1_batch = np.array(state1_batch, 'double').reshape(batch_size, -1)
+        terminal1_batch = np.array(terminal1_batch, 'double').reshape(batch_size, -1)
+        reward_batch = np.array(reward_batch, 'double').reshape(batch_size, -1)
+        action_batch = np.array(action_batch, 'double').reshape(batch_size, -1)
+
+        return state0_batch, action_batch, reward_batch, state1_batch, terminal1_batch
+
+    def append(self, observation, action, reward, terminal, training=True):
+        super(SequentialMemory, self).append(observation, action, reward, terminal, training=training)
+
+        # This needs to be understood as follows: in `observation`, take `action`, obtain `reward`
+        # and whether the next state is `terminal` or not.
+        if training:
+            self.observations.append(observation)
+            self.actions.append(action)
+            self.rewards.append(reward)
+            self.terminals.append(terminal)
+
+    @property
+    def nb_entries(self):
+        return len(self.observations)
+
+    def get_config(self):
+        config = super(SequentialMemory, self).get_config()
+        config['limit'] = self.limit
+        return config
diff --git a/examples/auto_compression/amc/rl_libs/hanlab/utils.py b/examples/auto_compression/amc/rl_libs/hanlab/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cdbb05012d64d8115a67317859d2f5876b7f0ab
--- /dev/null
+++ b/examples/auto_compression/amc/rl_libs/hanlab/utils.py
@@ -0,0 +1,261 @@
+# Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices"
+# Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han
+# {jilin, songhan}@mit.edu
+
+import os
+import torch
+import time
+import sys
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        if self.count > 0:
+            self.avg = self.sum / self.count
+
+    def accumulate(self, val, n=1):
+        self.sum += val
+        self.count += n
+        if self.count > 0:
+            self.avg = self.sum / self.count
+
+
+class TextLogger(object):
+    """Write log immediately to the disk"""
+    def __init__(self, filepath):
+        self.f = open(filepath, 'w')
+        self.fid = self.f.fileno()
+        self.filepath = filepath
+
+    def close(self):
+        self.f.close()
+
+    def write(self, content):
+        self.f.write(content)
+        self.f.flush()
+        os.fsync(self.fid)
+
+    def write_buf(self, content):
+        self.f.write(content)
+
+    def print_and_write(self, content):
+        print(content)
+        self.write(content+'\n')
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    batch_size = target.size(0)
+    num = output.size(1)
+    target_topk = []
+    appendices = []
+    for k in topk:
+        if k <= num:
+            target_topk.append(k)
+        else:
+            appendices.append([0.0])
+    topk = target_topk
+    maxk = max(topk)
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res + appendices
+
+
+def to_numpy(var):
+    use_cuda = torch.cuda.is_available()
+    return var.cpu().data.numpy() if use_cuda else var.data.numpy()
+
+
+def to_tensor(ndarray, requires_grad=False):  # return a float tensor by default
+    tensor = torch.from_numpy(ndarray).float()  # by default does not require grad
+    if requires_grad:
+        tensor.requires_grad_()
+    return tensor.cuda() if torch.cuda.is_available() else tensor
+
+
+def measure_layer_for_pruning(layer, x):
+    def get_layer_type(layer):
+        layer_str = str(layer)
+        return layer_str[:layer_str.find('(')].strip()
+
+    def get_layer_param(model):
+        import operator
+        import functools
+
+        return sum([functools.reduce(operator.mul, i.size(), 1) for i in model.parameters()])
+
+    multi_add = 1
+    type_name = get_layer_type(layer)
+
+    # ops_conv
+    if type_name in ['Conv2d']:
+        out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /
+                    layer.stride[0] + 1)
+        out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /
+                    layer.stride[1] + 1)
+        layer.flops = layer.in_channels * layer.out_channels * layer.kernel_size[0] *  \
+                    layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
+        layer.params = get_layer_param(layer)
+    # ops_linear
+    elif type_name in ['Linear']:
+        weight_ops = layer.weight.numel() * multi_add
+        bias_ops = layer.bias.numel()
+        layer.flops = weight_ops + bias_ops
+        layer.params = get_layer_param(layer)
+    return
+
+
+def least_square_sklearn(X, Y):
+    from sklearn.linear_model import LinearRegression
+    reg = LinearRegression(fit_intercept=False)
+    reg.fit(X, Y)
+    return reg.coef_
+
+
+def get_output_folder(parent_dir, env_name):
+    """Return save folder.
+    Assumes folders in the parent_dir have suffix -run{run
+    number}. Finds the highest run number and sets the output folder
+    to that number + 1. This is just convenient so that if you run the
+    same script multiple times tensorboard can plot all of the results
+    on the same plots with different names.
+    Parameters
+    ----------
+    parent_dir: str
+      Path of the directory containing all experiment runs.
+    Returns
+    -------
+    parent_dir/run_dir
+      Path to this run's save directory.
+    """
+    os.makedirs(parent_dir, exist_ok=True)
+    experiment_id = 0
+    for folder_name in os.listdir(parent_dir):
+        if not os.path.isdir(os.path.join(parent_dir, folder_name)):
+            continue
+        try:
+            folder_name = int(folder_name.split('-run')[-1])
+            if folder_name > experiment_id:
+                experiment_id = folder_name
+        except:
+            pass
+    experiment_id += 1
+
+    parent_dir = os.path.join(parent_dir, env_name)
+    parent_dir = parent_dir + '-run{}'.format(experiment_id)
+    os.makedirs(parent_dir, exist_ok=True)
+    return parent_dir
+
+
+
+# Custom progress bar
+_, term_width = os.popen('stty size', 'r').read().split()
+term_width = int(term_width)
+TOTAL_BAR_LENGTH = 40.
+last_time = time.time()
+begin_time = last_time
+
+
+def progress_bar(current, total, msg=None):
+    def format_time(seconds):
+        days = int(seconds / 3600 / 24)
+        seconds = seconds - days * 3600 * 24
+        hours = int(seconds / 3600)
+        seconds = seconds - hours * 3600
+        minutes = int(seconds / 60)
+        seconds = seconds - minutes * 60
+        secondsf = int(seconds)
+        seconds = seconds - secondsf
+        millis = int(seconds * 1000)
+
+        f = ''
+        i = 1
+        if days > 0:
+            f += str(days) + 'D'
+            i += 1
+        if hours > 0 and i <= 2:
+            f += str(hours) + 'h'
+            i += 1
+        if minutes > 0 and i <= 2:
+            f += str(minutes) + 'm'
+            i += 1
+        if secondsf > 0 and i <= 2:
+            f += str(secondsf) + 's'
+            i += 1
+        if millis > 0 and i <= 2:
+            f += str(millis) + 'ms'
+            i += 1
+        if f == '':
+            f = '0ms'
+        return f
+
+    global last_time, begin_time
+    if current == 0:
+        begin_time = time.time()  # Reset for new bar.
+
+    cur_len = int(TOTAL_BAR_LENGTH*current/total)
+    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
+
+    sys.stdout.write(' [')
+    for i in range(cur_len):
+        sys.stdout.write('=')
+    sys.stdout.write('>')
+    for i in range(rest_len):
+        sys.stdout.write('.')
+    sys.stdout.write(']')
+
+    cur_time = time.time()
+    step_time = cur_time - last_time
+    last_time = cur_time
+    tot_time = cur_time - begin_time
+
+    L = []
+    L.append('  Step: %s' % format_time(step_time))
+    L.append(' | Tot: %s' % format_time(tot_time))
+    if msg:
+        L.append(' | ' + msg)
+
+    msg = ''.join(L)
+    sys.stdout.write(msg)
+    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
+        sys.stdout.write(' ')
+
+    # Go back to the center of the bar.
+    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
+        sys.stdout.write('\b')
+    sys.stdout.write(' %d/%d ' % (current+1, total))
+
+    if current < total-1:
+        sys.stdout.write('\r')
+    else:
+        sys.stdout.write('\n')
+    sys.stdout.flush()
+
+# logging
+def prRed(prt): print("\033[91m {}\033[00m" .format(prt))
+def prGreen(prt): print("\033[92m {}\033[00m" .format(prt))
+def prYellow(prt): print("\033[93m {}\033[00m" .format(prt))
+def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt))
+def prPurple(prt): print("\033[95m {}\033[00m" .format(prt))
+def prCyan(prt): print("\033[96m {}\033[00m" .format(prt))
+def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt))
+def prBlack(prt): print("\033[98m {}\033[00m" .format(prt))
\ No newline at end of file
diff --git a/licenses/hanlab-amc-license.txt b/licenses/hanlab-amc-license.txt
new file mode 100755
index 0000000000000000000000000000000000000000..8d5cc049cf16d6019ec9a2fd2d7d41797b1b3f3b
--- /dev/null
+++ b/licenses/hanlab-amc-license.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 MIT_Han_Lab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.