diff --git a/examples/auto_compression/amc/README.md b/examples/auto_compression/amc/README.md index acf8ed2044cf0e5ae1a77dfacca3ed3ae8823ee8..94be27d325aa6e6c520803bae5ef5452c7a25189 100755 --- a/examples/auto_compression/amc/README.md +++ b/examples/auto_compression/amc/README.md @@ -32,6 +32,35 @@ AMC [1] trains a Deep Deterministic Policy Gradient (DDPG) RL agent to compress We thank Prof. Song Han and his team for their [help](https://github.com/mit-han-lab/amc-compressed-models) with certain critical parts of this implementation. However, all bugs in interpretation and/or implementation are ours ;-). +## Installation +Our AMC implementation is designed so that you can easily switch between RL libraries (i.e. different agent implementations). The `--amc-rllib` argument selects which library to use. +For `--amc-rllib=coach` you need to install Coach in your Python virtual environment. The "hanlab" library, `--amc-rllib=hanlab`, refers to HAN Lab's DDPG agent implementation. + +### Using Intel AI Lab's Coach +Some of the features required by AMC are not yet in an official [Coach](https://github.com/NervanaSystems/coach) release, so you should use the `master` branch. +Therefore, follow Coach's [installation instructions](https://github.com/NervanaSystems/coach#installation) for a development environment. +<br> +Coach uses TensorFlow and Distiller uses PyTorch, and the two frameworks do not share GPUs well. The easiest work-around is to execute the Coach code (the RL agent) on the CPU and the Distiller code on the GPU(s). +<br> +To do this, uninstall TensorFlow:<br> +`$ pip uninstall tensorflow tensorflow-gpu` +<br> +and then reinstall the CPU-only TensorFlow package: <br> +`$ pip install tensorflow` + +### Using MIT HAN Lab's DDPG agent +We integrated MIT HAN Lab's AMC DDPG agent directly into the code base, so no extra installation is required to use it. 
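+
+A notable detail of this agent (see "Notable algorithm details" below) is its exploration noise: the action proposed by the actor is perturbed by sampling from a truncated normal distribution centered on that action, with a noise scale that decays from episode to episode. The snippet below is a minimal, standalone sketch of this idea; the helper name `explore` and the default values for `init_delta`, `delta_decay` and `warmup` are illustrative only, while the real values come from the agent's arguments in `rl_libs/hanlab/agent.py`.
+
+```python
+import numpy as np
+from scipy import stats
+
+def explore(action, episode, lbound=0.0, rbound=1.0,
+            init_delta=0.5, delta_decay=0.95, warmup=100):
+    """Perturb a pruning-ratio action with truncated-normal noise, clipped to [lbound, rbound]."""
+    # The noise scale (sigma) shrinks geometrically once the warm-up episodes are over.
+    sigma = init_delta * (delta_decay ** max(episode - warmup, 0))
+    noisy = stats.truncnorm.rvs((lbound - action) / sigma,
+                                (rbound - action) / sigma,
+                                loc=action, scale=sigma)
+    return float(np.clip(noisy, lbound, rbound))
+
+print(explore(action=0.7, episode=120))  # e.g. 0.66 -- a slightly perturbed pruning ratio
+```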
+ +## Notable algorithm details + +### Feature-map reconstruction +With `--amc-prune-method=fm-reconstruction`, the weights remaining after pruning a layer's channels are re-computed with least-squares regression, so that the pruned layer's output feature-maps approximate those of the original layer. + +### DDPG Agent +- AMC uses a truncated normal distribution for its exploration policy +- AMC uses reward shaping + +## References + [1] AMC: AutoML for Model Compression and Acceleration on Mobile Devices.<br> Yihui He, Ji Lin, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han.<br> In Proceedings of the European Conference on Computer Vision (ECCV), 2018.<br> diff --git a/examples/auto_compression/amc/amc.py b/examples/auto_compression/amc/amc.py index 46d633c265e109107ff5e6e30e277727ad5d7ece..d30473e27b082e6caa26ff68d4f6d87129b72a8b 100755 --- a/examples/auto_compression/amc/amc.py +++ b/examples/auto_compression/amc/amc.py @@ -15,7 +15,7 @@ # """ -$ python3 amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../../ssl/checkpoints/checkpoint_trained_dense.pth.tar --amc --amc-procol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --amc-rllib=private -j=1 +$ python3 amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../../ssl/checkpoints/checkpoint_trained_dense.pth.tar --amc --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --amc-rllib=hanlab -j=1 """ @@ -143,9 +143,9 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo env2 = create_environment() steps_per_episode = env1.steps_per_episode rl.solve(env1, env2) - elif args.amc_rllib == "private": - from rl_libs.private import private_if - rl = private_if.RlLibInterface() + elif args.amc_rllib == "hanlab": + from rl_libs.hanlab import hanlab_if + rl = hanlab_if.RlLibInterface() args.observation_len = len(Observation._fields) rl.solve(env1, args) elif args.amc_rllib == "coach": diff --git a/examples/auto_compression/amc/amc_args.py b/examples/auto_compression/amc/amc_args.py index 38ecb0af5e79e4279ad8aaffcae89f4771cba1af..2238691a02976a5bc6f0be8b8cded28ee151e073 100755 --- a/examples/auto_compression/amc/amc_args.py +++ b/examples/auto_compression/amc/amc_args.py @@ -46,7 +46,7 @@ def add_automl_args(argparser): default="l1-rank", help="The pruning method") group.add_argument('--amc-rllib', choices=["coach", "spinningup", - "private", + "hanlab", "random"], default=None, help="Choose which RL library to use") group.add_argument('--amc-group-size', type=int, default=1, diff --git a/examples/auto_compression/amc/auto_compression_channels.yaml b/examples/auto_compression/amc/auto_compression_channels.yaml index 558c9754f1036454074ade0057c93c723764d963..737c3e492ee488af4221d4591c94cd0e577cbb1e 100755 --- a/examples/auto_compression/amc/auto_compression_channels.yaml +++ b/examples/auto_compression/amc/auto_compression_channels.yaml @@ -1,6 +1,6 @@ # rl_lib: # # Choose which RL library to use: Coach from Intel AI Lab, or Spinup from OpenAI -# name: private # one of: coach, spinningup, private +# name: hanlab # one of: coach, spinningup, hanlab network: mobilenet: diff --git a/examples/auto_compression/amc/auto_compression_filters.yaml b/examples/auto_compression/amc/auto_compression_filters.yaml index c80d59c43074104db37d9d43985c883c3d685040..16293488f626e447e51763fe94c33896ffe0b2ce 100755 --- a/examples/auto_compression/amc/auto_compression_filters.yaml +++ 
b/examples/auto_compression/amc/auto_compression_filters.yaml @@ -1,6 +1,6 @@ # rl_lib: # # Choose which RL library to use: Coach from Intel AI Lab, or Spinup from OpenAI -# name: private # one of: coach, spinningup, private +# name: hanlab # one of: coach, spinningup, hanlab network: mobilenet: diff --git a/examples/auto_compression/amc/jupyter/amc_plain20.ipynb b/examples/auto_compression/amc/jupyter/amc_plain20.ipynb index 3ccfb8c76fc5b5b33d757af1eee7341f7b79b5df..e58ec546db892c9e263b6f4c25b4d58d97c45192 100644 --- a/examples/auto_compression/amc/jupyter/amc_plain20.ipynb +++ b/examples/auto_compression/amc/jupyter/amc_plain20.ipynb @@ -79,7 +79,7 @@ "\n", "The command-line is provided below:\n", " \n", - " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=0.5 --etrs=0.5 --amc-rllib=private -j=1\n", + " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=0.5 --etrs=0.5 --amc-rllib=hanlab -j=1\n", " \n", "Each colored line represents one experiment execution instance. We plot the search-Top1 score of discovered networks as the RL-based AMC system learns to find better compressed networks. 
You might be impressed by:\n", "* The variability in behavior, which is typical for RL algorithms.\n", @@ -358,7 +358,7 @@ "source": [ "## Using a different reward function\n", "\n", - " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private-punish compress_classifier.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=punish-agent --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=../automated_deep_compression/auto_compression_channels.yaml --evs=0.5 --etrs=0.5 --amc-rllib=private -j=1" + " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private-punish compress_classifier.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=punish-agent --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=../automated_deep_compression/auto_compression_channels.yaml --evs=0.5 --etrs=0.5 --amc-rllib=hanlab -j=1" ] }, { @@ -448,7 +448,7 @@ " validation=5000\n", " test=750\n", "\n", - " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=${CHECKPOINTS_PATH}/checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=1.0 --etrs=0.01 --amc-rllib=private -j=1\n", + " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=${CHECKPOINTS_PATH}/checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=1.0 --etrs=0.01 --amc-rllib=hanlab -j=1\n", " \n", " time python parallel-finetune.py --scan-dir=${AMC_EXP_PATH}/plain20-ddpg-private/2019.08.01-181040 --arch=plain20_cifar --lr=0.1 --vs=0 -p=50 --compress=plain20_fine_tune.yaml ${CIFAR10_PATH} -j=1 --epochs=60 --output-csv=ft_60epoch_results.csv --processes=16" ] @@ -491,7 +491,7 @@ "source": [ "Now repeating the original experiment, with checkpoint savings so we can fine-tune the best ones\n", "\n", - " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=${CHECKPOINTS_PATH}/checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=0.5 --etrs=0.01 --amc-rllib=private -j=1 --amc-save-chkpts\n", + " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/plain20-ddpg-private amc.py --arch=plain20_cifar ${CIFAR10_PATH} --resume=${CHECKPOINTS_PATH}/checkpoint.plain20_cifar.pth.tar --lr=0.05 --amc-protocol=mac-constrained 
--amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --evs=0.5 --etrs=0.01 --amc-rllib=hanlab -j=1 --amc-save-chkpts\n", " \n", " time python ../../classifier_compression/parallel-finetune.py --scan-dir=${AMC_EXP_PATH}/plain20-ddpg-private/2019.08.03-000628 --arch=plain20_cifar --lr=0.1 --vs=0 -p=50 --compress=../plain20_fine_tune.yaml ${CIFAR10_PATH} -j=1 --epochs=60 --output-csv=ft_60epoch_results.csv --processes=16 --top-performing-chkpts" ] diff --git a/examples/auto_compression/amc/jupyter/amc_resnet20.ipynb b/examples/auto_compression/amc/jupyter/amc_resnet20.ipynb index 4124e146c7ed022fa8df591b5e5854953032e139..4232c7baed315ff371399cce74e2fb7407ecf058 100644 --- a/examples/auto_compression/amc/jupyter/amc_resnet20.ipynb +++ b/examples/auto_compression/amc/jupyter/amc_resnet20.ipynb @@ -79,7 +79,7 @@ "\n", "The command-line is provided below:\n", " \n", - " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/resnet20-ddpg-private amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --amc-rllib=private -j=1\n", + " time python3 ../../classifier_compression/multi-run.py ${AMC_EXP_PATH}/resnet20-ddpg-private amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --amc-protocol=mac-constrained --amc-action-range 0.05 1.0 --amc-target-density=0.5 -p=50 --etes=0.075 --amc-ft-epochs=0 --amc-prune-pattern=channels --amc-prune-method=fm-reconstruction --amc-agent-algo=DDPG --amc-cfg=auto_compression_channels.yaml --amc-rllib=hanlab -j=1\n", " \n", "Each colored line represents one experiment execution instance. We plot the search-Top1 score of discovered networks as the RL-based AMC system learns to find better compressed networks. You might be impressed by:\n", "* The variability in behavior, which is typical for RL algorithms.\n", @@ -233,13 +233,13 @@ "pycharm": { "stem_cell": { "cell_type": "raw", + "source": [], "metadata": { "collapsed": false - }, - "source": [] + } } } }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/examples/auto_compression/amc/rl_libs/hanlab/README.md b/examples/auto_compression/amc/rl_libs/hanlab/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9aebcd9ea12a4be30774068d0a10f61d00e00fd9 --- /dev/null +++ b/examples/auto_compression/amc/rl_libs/hanlab/README.md @@ -0,0 +1,15 @@ +The code in this directory originates from HAN Lab's [AMC github repository](https://github.com/mit-han-lab/amc-release). 
+ +We copied the DDPG files from [HAN Lab's github](https://github.com/mit-han-lab/amc-release) to `distiller/examples/auto_compression/amc/rl_libs/hanlab/`.<br> +Specifically: +- `mit-han-lab/amc-release/tree/master/lib/agent.py` ==> `distiller/examples/auto_compression/amc/rl_libs/hanlab` +- `mit-han-lab/amc-release/tree/master/lib/memory.py` ==> `distiller/examples/auto_compression/amc/rl_libs/hanlab` +- `mit-han-lab/amc-release/tree/master/lib/utils.py` ==> `distiller/examples/auto_compression/amc/rl_libs/hanlab` + +Function `train()` was copied from `mit-han-lab/amc-release/tree/master/lib/amc_search.py` to the new file `distiller/examples/auto_compression/amc/rl_libs/hanlab/agent.py`. + +Some non-functional changes were introduced in order for this code to run under Distiller. + +The MIT license was copied to `distiller/licenses/hanlab-amc-license.txt`. + + \ No newline at end of file diff --git a/examples/auto_compression/amc/rl_libs/private/__init__.py b/examples/auto_compression/amc/rl_libs/hanlab/__init__.py similarity index 100% rename from examples/auto_compression/amc/rl_libs/private/__init__.py rename to examples/auto_compression/amc/rl_libs/hanlab/__init__.py diff --git a/examples/auto_compression/amc/rl_libs/hanlab/agent.py b/examples/auto_compression/amc/rl_libs/hanlab/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..0ac2045c64bb535e7168fd75a9373672c8b128b1 --- /dev/null +++ b/examples/auto_compression/amc/rl_libs/hanlab/agent.py @@ -0,0 +1,298 @@ +# Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" +# Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han +# {jilin, songhan}@mit.edu + +import numpy as np +import torch +import torch.nn as nn +from torch.optim import Adam +import os +from .memory import SequentialMemory +from .utils import to_numpy, to_tensor +from copy import deepcopy + +criterion = nn.MSELoss() +USE_CUDA = torch.cuda.is_available() + + +class Actor(nn.Module): + def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300): + super(Actor, self).__init__() + self.fc1 = nn.Linear(nb_states, hidden1) + self.fc2 = nn.Linear(hidden1, hidden2) + self.fc3 = nn.Linear(hidden2, nb_actions) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + out = self.relu(out) + out = self.fc3(out) + out = self.sigmoid(out) + return out + + +class Critic(nn.Module): + def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300): + super(Critic, self).__init__() + self.fc11 = nn.Linear(nb_states, hidden1) + self.fc12 = nn.Linear(nb_actions, hidden1) + self.fc2 = nn.Linear(hidden1, hidden2) + self.fc3 = nn.Linear(hidden2, 1) + self.relu = nn.ReLU() + + def forward(self, xs): + x, a = xs + out = self.fc11(x) + self.fc12(a) + out = self.relu(out) + out = self.fc2(out) + out = self.relu(out) + out = self.fc3(out) + return out + + +class DDPG(object): + def __init__(self, nb_states, nb_actions, args): + + self.nb_states = nb_states + self.nb_actions = nb_actions + + # Create Actor and Critic Network + net_cfg = { + 'hidden1': args.hidden1, + 'hidden2': args.hidden2, + } + self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg) + self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg) + self.actor_optim = Adam(self.actor.parameters(), lr=args.lr_a) + + self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg) + self.critic_target = Critic(self.nb_states, self.nb_actions, 
**net_cfg) + self.critic_optim = Adam(self.critic.parameters(), lr=args.lr_c) + + self.hard_update(self.actor_target, self.actor) # Make sure target is with the same weight + self.hard_update(self.critic_target, self.critic) + + # Create replay buffer + self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length) + + # Hyper-parameters + self.batch_size = args.bsize + self.tau = args.tau + self.discount = args.discount + self.depsilon = 1.0 / args.epsilon + self.lbound = 0. # args.lbound + self.rbound = 1. # args.rbound + + # noise + self.init_delta = args.init_delta + self.delta_decay = args.delta_decay + self.warmup = args.warmup + + # + self.epsilon = 1.0 + self.is_training = True + + # + if USE_CUDA: self.cuda() + + # moving average baseline + self.moving_average = None + self.moving_alpha = 0.5 # based on batch, so small + + def update_policy(self): + # Sample batch + state_batch, action_batch, reward_batch, \ + next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size) + + # normalize the reward + batch_mean_reward = np.mean(reward_batch) + if self.moving_average is None: + self.moving_average = batch_mean_reward + else: + self.moving_average += self.moving_alpha * (batch_mean_reward - self.moving_average) + reward_batch -= self.moving_average + + # Prepare for the target q batch + with torch.no_grad(): + next_q_values = self.critic_target([ + to_tensor(next_state_batch), + self.actor_target(to_tensor(next_state_batch)), + ]) + + target_q_batch = to_tensor(reward_batch) + \ + self.discount * to_tensor(terminal_batch.astype(np.float)) * next_q_values + + # Critic update + self.critic.zero_grad() + + q_batch = self.critic([to_tensor(state_batch), to_tensor(action_batch)]) + + value_loss = criterion(q_batch, target_q_batch) + value_loss.backward() + self.critic_optim.step() + + # Actor update + self.actor.zero_grad() + + policy_loss = -self.critic([ + to_tensor(state_batch), + self.actor(to_tensor(state_batch)) + ]) + + policy_loss = policy_loss.mean() + policy_loss.backward() + self.actor_optim.step() + + # Target update + self.soft_update(self.actor_target, self.actor) + self.soft_update(self.critic_target, self.critic) + + def eval(self): + self.actor.eval() + self.actor_target.eval() + self.critic.eval() + self.critic_target.eval() + + def cuda(self): + self.actor.cuda() + self.actor_target.cuda() + self.critic.cuda() + self.critic_target.cuda() + + def observe(self, r_t, s_t, s_t1, a_t, done): + if self.is_training: + self.memory.append(s_t, a_t, r_t, done) # save to memory + # self.s_t = s_t1 + + def random_action(self): + action = np.random.uniform(self.lbound, self.rbound, self.nb_actions) + return action + + def select_action(self, s_t, episode): + # assert episode >= self.warmup, 'Episode: {} warmup: {}'.format(episode, self.warmup) + action = to_numpy(self.actor(to_tensor(np.array(s_t).reshape(1, -1)))).squeeze(0) + delta = self.init_delta * (self.delta_decay ** (episode - self.warmup)) + # action += self.is_training * max(self.epsilon, 0) * self.random_process.sample() + action = self.sample_from_truncated_normal_distribution(lower=self.lbound, upper=self.rbound, mu=action, sigma=delta) + action = np.clip(action, self.lbound, self.rbound) + + # self.a_t = action + return action + + def reset(self, obs): + pass + # self.s_t = obs + # self.random_process.reset_states() + + def load_weights(self, output): + if output is None: return + + self.actor.load_state_dict( + torch.load('{}/actor.pkl'.format(output)) + ) + + 
self.critic.load_state_dict( + torch.load('{}/critic.pkl'.format(output)) + ) + + def save_model(self, output): + torch.save( + self.actor.state_dict(), + '{}/actor.pkl'.format(output) + ) + torch.save( + self.critic.state_dict(), + '{}/critic.pkl'.format(output) + ) + + def soft_update(self, target, source): + for target_param, param in zip(target.parameters(), source.parameters()): + target_param.data.copy_( + target_param.data * (1.0 - self.tau) + param.data * self.tau + ) + + def hard_update(self, target, source): + for target_param, param in zip(target.parameters(), source.parameters()): + target_param.data.copy_(param.data) + + def sample_from_truncated_normal_distribution(self, lower, upper, mu, sigma, size=1): + from scipy import stats + return stats.truncnorm.rvs((lower-mu)/sigma, (upper-mu)/sigma, loc=mu, scale=sigma, size=size) + + +def train(num_episode, agent, env, output, warmup): + text_writer = open(os.path.join(output, 'sh_log.txt'), 'w') + print('=> Output path: {}...'.format(output)) + + agent.is_training = True + step = episode = episode_steps = 0 + episode_reward = 0. + observation = None + T = [] # trajectory + while episode < num_episode: # counting based on episode + # reset if it is the start of episode + if observation is None: + observation = deepcopy(env.reset()) + agent.reset(observation) + + # agent pick action ... + if episode <= warmup: + action = agent.random_action() + else: + action = agent.select_action(observation, episode=episode) + + # env response with next_observation, reward, terminate_info + observation2, reward, done, info = env.step(action) + env.render() + observation2 = deepcopy(observation2) + + T.append([reward, deepcopy(observation), deepcopy(observation2), action, done]) + + # [optional] save intermideate model + #if episode % int(num_episode / 3) == 0: + # agent.save_model(output) + + # update + step += 1 + episode_steps += 1 + episode_reward += reward + observation = deepcopy(observation2) + + if done: # end of episode + print('#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}'.format(episode, episode_reward, + info['accuracy'], + info['compress_ratio'])) + text_writer.write( + '#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}\n'.format(episode, episode_reward, + info['accuracy'], + info['compress_ratio'])) + final_reward = T[-1][0] + # print('final_reward: {}'.format(final_reward)) + # agent observe and update policy + for r_t, s_t, s_t1, a_t, done in T: + agent.observe(final_reward, s_t, s_t1, a_t, done) + if episode > warmup: + agent.update_policy() + + # reset + observation = None + episode_steps = 0 + episode_reward = 0. 
+ episode += 1 + T = [] + + # tfwriter.add_scalar('reward/last', final_reward, episode) + # tfwriter.add_scalar('reward/best', env.best_reward, episode) + # tfwriter.add_scalar('info/accuracy', info['accuracy'], episode) + # tfwriter.add_scalar('info/compress_ratio', info['compress_ratio'], episode) + # tfwriter.add_text('info/best_policy', str(env.best_strategy), episode) + # # record the preserve rate for each layer + # for i, preserve_rate in enumerate(env.strategy): + # tfwriter.add_scalar('preserve_rate/{}'.format(i), preserve_rate, episode) + + text_writer.write('best reward: {}\n'.format(env.best_reward)) + #text_writer.write('best policy: {}\n'.format(env.best_strategy)) + text_writer.close() diff --git a/examples/auto_compression/amc/rl_libs/private/private_if.py b/examples/auto_compression/amc/rl_libs/hanlab/hanlab_if.py similarity index 90% rename from examples/auto_compression/amc/rl_libs/private/private_if.py rename to examples/auto_compression/amc/rl_libs/hanlab/hanlab_if.py index 2ca43cebad36d05bf02b08261d7eca7be9cd4849..0cbb4df6a58ca534eacad9033e01b525ca13dda8 100755 --- a/examples/auto_compression/amc/rl_libs/private/private_if.py +++ b/examples/auto_compression/amc/rl_libs/hanlab/hanlab_if.py @@ -14,7 +14,7 @@ # limitations under the License. # -from examples.auto_compression.amc.rl_libs.private.agent import DDPG, train +from examples.auto_compression.amc.rl_libs.hanlab.agent import DDPG, train import logging @@ -27,10 +27,10 @@ class ArgsContainer(object): class RlLibInterface(object): - """Interface to a private DDPG impelementation.""" + """Interface to a hanlab DDPG impelementation.""" def solve(self, env, args): - msglogger.info("AMC: Using private") + msglogger.info("AMC: Using hanlab") agent_args = ArgsContainer() agent_args.bsize = args.batch_size diff --git a/examples/auto_compression/amc/rl_libs/hanlab/memory.py b/examples/auto_compression/amc/rl_libs/hanlab/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..493740c6351bc19b7648e7ba5fce9735bb783de6 --- /dev/null +++ b/examples/auto_compression/amc/rl_libs/hanlab/memory.py @@ -0,0 +1,228 @@ +# Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" +# Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han +# {jilin, songhan}@mit.edu + +from __future__ import absolute_import +from collections import deque, namedtuple +import warnings +import random + +import numpy as np + +# [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/memory.py + +# This is to be understood as a transition: Given `state0`, performing `action` +# yields `reward` and results in `state1`, which might be `terminal`. +Experience = namedtuple('Experience', 'state0, action, reward, state1, terminal1') + + +def sample_batch_indexes(low, high, size): + if high - low >= size: + # We have enough data. Draw without replacement, that is each index is unique in the + # batch. We cannot use `np.random.choice` here because it is horribly inefficient as + # the memory grows. See https://github.com/numpy/numpy/issues/2764 for a discussion. + # `random.sample` does the same thing (drawing without replacement) and is way faster. + r = range(low, high) + batch_idxs = random.sample(r, size) + else: + # Not enough data. Help ourselves with sampling from the range, but the same index + # can occur multiple times. This is not good and should be avoided by picking a + # large enough warm-up phase. + warnings.warn( + 'Not enough entries to sample without replacement. 
' + 'Consider increasing your warm-up phase to avoid oversampling!') + batch_idxs = np.random.random_integers(low, high - 1, size=size) + assert len(batch_idxs) == size + return batch_idxs + + +class RingBuffer(object): + def __init__(self, maxlen): + self.maxlen = maxlen + self.start = 0 + self.length = 0 + self.data = [None for _ in range(maxlen)] + + def __len__(self): + return self.length + + def __getitem__(self, idx): + if idx < 0 or idx >= self.length: + raise KeyError() + return self.data[(self.start + idx) % self.maxlen] + + def append(self, v): + if self.length < self.maxlen: + # We have space, simply increase the length. + self.length += 1 + elif self.length == self.maxlen: + # No space, "remove" the first item. + self.start = (self.start + 1) % self.maxlen + else: + # This should never happen. + raise RuntimeError() + self.data[(self.start + self.length - 1) % self.maxlen] = v + + +def zeroed_observation(observation): + if hasattr(observation, 'shape'): + return np.zeros(observation.shape) + elif hasattr(observation, '__iter__'): + out = [] + for x in observation: + out.append(zeroed_observation(x)) + return out + else: + return 0. + + +class Memory(object): + def __init__(self, window_length, ignore_episode_boundaries=False): + self.window_length = window_length + self.ignore_episode_boundaries = ignore_episode_boundaries + + self.recent_observations = deque(maxlen=window_length) + self.recent_terminals = deque(maxlen=window_length) + + def sample(self, batch_size, batch_idxs=None): + raise NotImplementedError() + + def append(self, observation, action, reward, terminal, training=True): + self.recent_observations.append(observation) + self.recent_terminals.append(terminal) + + def get_recent_state(self, current_observation): + # This code is slightly complicated by the fact that subsequent observations might be + # from different episodes. We ensure that an experience never spans multiple episodes. + # This is probably not that important in practice but it seems cleaner. + state = [current_observation] + idx = len(self.recent_observations) - 1 + for offset in range(0, self.window_length - 1): + current_idx = idx - offset + current_terminal = self.recent_terminals[current_idx - 1] if current_idx - 1 >= 0 else False + if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): + # The previously handled observation was terminal, don't add the current one. + # Otherwise we would leak into a different episode. + break + state.insert(0, self.recent_observations[current_idx]) + while len(state) < self.window_length: + state.insert(0, zeroed_observation(state[0])) + return state + + def get_config(self): + config = { + 'window_length': self.window_length, + 'ignore_episode_boundaries': self.ignore_episode_boundaries, + } + return config + + +class SequentialMemory(Memory): + def __init__(self, limit, **kwargs): + super(SequentialMemory, self).__init__(**kwargs) + + self.limit = limit + + # Do not use deque to implement the memory. This data structure may seem convenient but + # it is way too slow on random access. Instead, we use our own ring buffer implementation. + self.actions = RingBuffer(limit) + self.rewards = RingBuffer(limit) + self.terminals = RingBuffer(limit) + self.observations = RingBuffer(limit) + + def sample(self, batch_size, batch_idxs=None): + if batch_idxs is None: + # Draw random indexes such that we have at least a single entry before each + # index. 
+ batch_idxs = sample_batch_indexes(0, self.nb_entries - 1, size=batch_size) + batch_idxs = np.array(batch_idxs) + 1 + assert np.min(batch_idxs) >= 1 + assert np.max(batch_idxs) < self.nb_entries + assert len(batch_idxs) == batch_size + + # Create experiences + experiences = [] + for idx in batch_idxs: + terminal0 = self.terminals[idx - 2] if idx >= 2 else False + while terminal0: + # Skip this transition because the environment was reset here. Select a new, random + # transition and use this instead. This may cause the batch to contain the same + # transition twice. + idx = sample_batch_indexes(1, self.nb_entries, size=1)[0] + terminal0 = self.terminals[idx - 2] if idx >= 2 else False + assert 1 <= idx < self.nb_entries + + # This code is slightly complicated by the fact that subsequent observations might be + # from different episodes. We ensure that an experience never spans multiple episodes. + # This is probably not that important in practice but it seems cleaner. + state0 = [self.observations[idx - 1]] + for offset in range(0, self.window_length - 1): + current_idx = idx - 2 - offset + current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False + if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): + # The previously handled observation was terminal, don't add the current one. + # Otherwise we would leak into a different episode. + break + state0.insert(0, self.observations[current_idx]) + while len(state0) < self.window_length: + state0.insert(0, zeroed_observation(state0[0])) + action = self.actions[idx - 1] + reward = self.rewards[idx - 1] + terminal1 = self.terminals[idx - 1] + + # Okay, now we need to create the follow-up state. This is state0 shifted on timestep + # to the right. Again, we need to be careful to not include an observation from the next + # episode if the last state is terminal. + state1 = [np.copy(x) for x in state0[1:]] + state1.append(self.observations[idx]) + + assert len(state0) == self.window_length + assert len(state1) == len(state0) + experiences.append(Experience(state0=state0, action=action, reward=reward, + state1=state1, terminal1=terminal1)) + assert len(experiences) == batch_size + return experiences + + def sample_and_split(self, batch_size, batch_idxs=None): + experiences = self.sample(batch_size, batch_idxs) + + state0_batch = [] + reward_batch = [] + action_batch = [] + terminal1_batch = [] + state1_batch = [] + for e in experiences: + state0_batch.append(e.state0) + state1_batch.append(e.state1) + reward_batch.append(e.reward) + action_batch.append(e.action) + terminal1_batch.append(0. if e.terminal1 else 1.) + + # Prepare and validate parameters. + state0_batch = np.array(state0_batch, 'double').reshape(batch_size, -1) + state1_batch = np.array(state1_batch, 'double').reshape(batch_size, -1) + terminal1_batch = np.array(terminal1_batch, 'double').reshape(batch_size, -1) + reward_batch = np.array(reward_batch, 'double').reshape(batch_size, -1) + action_batch = np.array(action_batch, 'double').reshape(batch_size, -1) + + return state0_batch, action_batch, reward_batch, state1_batch, terminal1_batch + + def append(self, observation, action, reward, terminal, training=True): + super(SequentialMemory, self).append(observation, action, reward, terminal, training=training) + + # This needs to be understood as follows: in `observation`, take `action`, obtain `reward` + # and weather the next state is `terminal` or not. 
+ if training: + self.observations.append(observation) + self.actions.append(action) + self.rewards.append(reward) + self.terminals.append(terminal) + + @property + def nb_entries(self): + return len(self.observations) + + def get_config(self): + config = super(SequentialMemory, self).get_config() + config['limit'] = self.limit + return config diff --git a/examples/auto_compression/amc/rl_libs/hanlab/utils.py b/examples/auto_compression/amc/rl_libs/hanlab/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8cdbb05012d64d8115a67317859d2f5876b7f0ab --- /dev/null +++ b/examples/auto_compression/amc/rl_libs/hanlab/utils.py @@ -0,0 +1,261 @@ +# Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" +# Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han +# {jilin, songhan}@mit.edu + +import os +import torch +import time +import sys + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + if self.count > 0: + self.avg = self.sum / self.count + + def accumulate(self, val, n=1): + self.sum += val + self.count += n + if self.count > 0: + self.avg = self.sum / self.count + + +class TextLogger(object): + """Write log immediately to the disk""" + def __init__(self, filepath): + self.f = open(filepath, 'w') + self.fid = self.f.fileno() + self.filepath = filepath + + def close(self): + self.f.close() + + def write(self, content): + self.f.write(content) + self.f.flush() + os.fsync(self.fid) + + def write_buf(self, content): + self.f.write(content) + + def print_and_write(self, content): + print(content) + self.write(content+'\n') + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + batch_size = target.size(0) + num = output.size(1) + target_topk = [] + appendices = [] + for k in topk: + if k <= num: + target_topk.append(k) + else: + appendices.append([0.0]) + topk = target_topk + maxk = max(topk) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + appendices + + +def to_numpy(var): + use_cuda = torch.cuda.is_available() + return var.cpu().data.numpy() if use_cuda else var.data.numpy() + + +def to_tensor(ndarray, requires_grad=False): # return a float tensor by default + tensor = torch.from_numpy(ndarray).float() # by default does not require grad + if requires_grad: + tensor.requires_grad_() + return tensor.cuda() if torch.cuda.is_available() else tensor + + +def measure_layer_for_pruning(layer, x): + def get_layer_type(layer): + layer_str = str(layer) + return layer_str[:layer_str.find('(')].strip() + + def get_layer_param(model): + import operator + import functools + + return sum([functools.reduce(operator.mul, i.size(), 1) for i in model.parameters()]) + + multi_add = 1 + type_name = get_layer_type(layer) + + # ops_conv + if type_name in ['Conv2d']: + out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / + layer.stride[0] + 1) + out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / + layer.stride[1] + 1) + layer.flops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ + 
layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add + layer.params = get_layer_param(layer) + # ops_linear + elif type_name in ['Linear']: + weight_ops = layer.weight.numel() * multi_add + bias_ops = layer.bias.numel() + layer.flops = weight_ops + bias_ops + layer.params = get_layer_param(layer) + return + + +def least_square_sklearn(X, Y): + from sklearn.linear_model import LinearRegression + reg = LinearRegression(fit_intercept=False) + reg.fit(X, Y) + return reg.coef_ + + +def get_output_folder(parent_dir, env_name): + """Return save folder. + Assumes folders in the parent_dir have suffix -run{run + number}. Finds the highest run number and sets the output folder + to that number + 1. This is just convenient so that if you run the + same script multiple times tensorboard can plot all of the results + on the same plots with different names. + Parameters + ---------- + parent_dir: str + Path of the directory containing all experiment runs. + Returns + ------- + parent_dir/run_dir + Path to this run's save directory. + """ + os.makedirs(parent_dir, exist_ok=True) + experiment_id = 0 + for folder_name in os.listdir(parent_dir): + if not os.path.isdir(os.path.join(parent_dir, folder_name)): + continue + try: + folder_name = int(folder_name.split('-run')[-1]) + if folder_name > experiment_id: + experiment_id = folder_name + except: + pass + experiment_id += 1 + + parent_dir = os.path.join(parent_dir, env_name) + parent_dir = parent_dir + '-run{}'.format(experiment_id) + os.makedirs(parent_dir, exist_ok=True) + return parent_dir + + + +# Custom progress bar +_, term_width = os.popen('stty size', 'r').read().split() +term_width = int(term_width) +TOTAL_BAR_LENGTH = 40. +last_time = time.time() +begin_time = last_time + + +def progress_bar(current, total, msg=None): + def format_time(seconds): + days = int(seconds / 3600 / 24) + seconds = seconds - days * 3600 * 24 + hours = int(seconds / 3600) + seconds = seconds - hours * 3600 + minutes = int(seconds / 60) + seconds = seconds - minutes * 60 + secondsf = int(seconds) + seconds = seconds - secondsf + millis = int(seconds * 1000) + + f = '' + i = 1 + if days > 0: + f += str(days) + 'D' + i += 1 + if hours > 0 and i <= 2: + f += str(hours) + 'h' + i += 1 + if minutes > 0 and i <= 2: + f += str(minutes) + 'm' + i += 1 + if secondsf > 0 and i <= 2: + f += str(secondsf) + 's' + i += 1 + if millis > 0 and i <= 2: + f += str(millis) + 'ms' + i += 1 + if f == '': + f = '0ms' + return f + + global last_time, begin_time + if current == 0: + begin_time = time.time() # Reset for new bar. + + cur_len = int(TOTAL_BAR_LENGTH*current/total) + rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 + + sys.stdout.write(' [') + for i in range(cur_len): + sys.stdout.write('=') + sys.stdout.write('>') + for i in range(rest_len): + sys.stdout.write('.') + sys.stdout.write(']') + + cur_time = time.time() + step_time = cur_time - last_time + last_time = cur_time + tot_time = cur_time - begin_time + + L = [] + L.append(' Step: %s' % format_time(step_time)) + L.append(' | Tot: %s' % format_time(tot_time)) + if msg: + L.append(' | ' + msg) + + msg = ''.join(L) + sys.stdout.write(msg) + for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): + sys.stdout.write(' ') + + # Go back to the center of the bar. 
+ for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): + sys.stdout.write('\b') + sys.stdout.write(' %d/%d ' % (current+1, total)) + + if current < total-1: + sys.stdout.write('\r') + else: + sys.stdout.write('\n') + sys.stdout.flush() + +# logging +def prRed(prt): print("\033[91m {}\033[00m" .format(prt)) +def prGreen(prt): print("\033[92m {}\033[00m" .format(prt)) +def prYellow(prt): print("\033[93m {}\033[00m" .format(prt)) +def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt)) +def prPurple(prt): print("\033[95m {}\033[00m" .format(prt)) +def prCyan(prt): print("\033[96m {}\033[00m" .format(prt)) +def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt)) +def prBlack(prt): print("\033[98m {}\033[00m" .format(prt)) \ No newline at end of file diff --git a/licenses/hanlab-amc-license.txt b/licenses/hanlab-amc-license.txt new file mode 100755 index 0000000000000000000000000000000000000000..8d5cc049cf16d6019ec9a2fd2d7d41797b1b3f3b --- /dev/null +++ b/licenses/hanlab-amc-license.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 MIT_Han_Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.