diff --git a/examples/automated_deep_compression/ADC.py b/examples/automated_deep_compression/ADC.py
index 215f6c25fb993f8151d5e9ba1c2b5bd8518e919b..80bca3a42807daecaafeb74e3d28937943822e61 100755
--- a/examples/automated_deep_compression/ADC.py
+++ b/examples/automated_deep_compression/ADC.py
@@ -226,14 +226,14 @@ def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn
     # Create a dictionary of parameters that Coach will handover to DistillerWrapperEnvironment
     # Once it creates it.
     services = distiller.utils.MutableNamedTuple({
-                'validate_fn': validate_fn,
-                'save_checkpoint_fn': save_checkpoint_fn,
-                'train_fn': train_fn})
+            'validate_fn': validate_fn,
+            'save_checkpoint_fn': save_checkpoint_fn,
+            'train_fn': train_fn})
 
     app_args = distiller.utils.MutableNamedTuple({
-                'dataset': dataset,
-                'arch': arch,
-                'optimizer_data': optimizer_data})
+            'dataset': dataset,
+            'arch': arch,
+            'optimizer_data': optimizer_data})
 
     amc_cfg = distiller.utils.MutableNamedTuple({
             'protocol': args.amc_protocol,
@@ -243,7 +243,8 @@ def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn
             'action_range': action_range,
             'conv_cnt': conv_cnt,
             'reward_frequency': args.amc_reward_frequency,
-            'ft_frequency': args.amc_ft_frequency})
+            'ft_frequency': args.amc_ft_frequency,
+            'pruning_pattern': "filters"})  # alternative: "channels"
 
     #net_wrapper = NetworkWrapper(model, app_args, services)
     #return sample_networks(net_wrapper, services)
@@ -344,9 +345,21 @@ resnet56_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight
                    "module.layer3.3.conv1.weight", "module.layer3.4.conv1.weight", "module.layer3.5.conv1.weight",
                    "module.layer3.6.conv1.weight", "module.layer3.7.conv1.weight", "module.layer3.8.conv1.weight"]
 
+plain20_params =  ["module.layer1.0.conv1.weight", "module.layer1.0.conv2.weight",
+                   "module.layer1.1.conv1.weight", "module.layer1.1.conv2.weight",
+                   "module.layer1.2.conv1.weight", "module.layer1.2.conv2.weight",
+                   "module.layer2.0.conv1.weight", "module.layer2.0.conv2.weight",
+                   "module.layer2.1.conv1.weight", "module.layer2.1.conv2.weight",
+                   "module.layer2.2.conv1.weight", "module.layer2.2.conv2.weight",
+                   "module.layer3.0.conv1.weight", "module.layer3.0.conv2.weight",
+                   "module.layer3.1.conv1.weight", "module.layer3.1.conv2.weight",
+                   "module.layer3.2.conv1.weight", "module.layer3.2.conv2.weight"]
+
+
 resnet50_layers = [param[:-len(".weight")] for param in resnet50_params]
 resnet20_layers = [param[:-len(".weight")] for param in resnet20_params]
 resnet56_layers = [param[:-len(".weight")] for param in resnet56_params]
+plain20_layers = [param[:-len(".weight")] for param in plain20_params]
 
 
 class NetworkWrapper(object):
@@ -375,6 +388,8 @@ class NetworkWrapper(object):
             resnet_layers = resnet56_layers
         elif self.app_args.arch == "resnet50":
             resnet_layers = resnet50_layers
+        elif self.app_args.arch == "plain20_cifar":
+            resnet_layers = plain20_layers
         return collect_conv_details(model, self.app_args.dataset, True, resnet_layers)
 
     def num_layers(self):
@@ -426,16 +441,16 @@ class NetworkWrapper(object):
         conv_pname = layer.name + ".weight"
         conv_p = distiller.model_find_param(self.model, conv_pname)
 
-        msglogger.info("ADC: removing %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname))
+        msglogger.info("ADC: trying to remove %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname))
 
         if prune_what == "channels":
             calculate_sparsity = distiller.sparsity_ch
-            remove_structures = distiller.remove_channels
+            remove_structures_fn = distiller.remove_channels
             group_type = "Channels"
         elif prune_what == "filters":
             calculate_sparsity = distiller.sparsity_3D
             group_type = "Filters"
-            remove_structures = distiller.remove_filters
+            remove_structures_fn = distiller.remove_filters
         else:
             raise ValueError("unsupported structure {}".format(prune_what))
         # Create a channel-ranking pruner
@@ -446,12 +461,12 @@ class NetworkWrapper(object):
 
         if (self.zeros_mask_dict[conv_pname].mask is None or
             calculate_sparsity(self.zeros_mask_dict[conv_pname].mask) == 0):
-            msglogger.info("remove_structures: aborting because there are no channels to prune")
+            msglogger.info("remove_structures: aborting because there are no structures to prune")
             return 0
 
         # Use the mask to prune
         self.zeros_mask_dict[conv_pname].apply_mask(conv_p)
-        remove_structures(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None)
+        remove_structures_fn(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None)
         conv_p = distiller.model_find_param(self.model, conv_pname)
         return 1 - (self.get_layer_macs(layer) / macs_before)
 
@@ -476,8 +491,6 @@ class NetworkWrapper(object):
 
 
 class DistillerWrapperEnvironment(gym.Env):
-    metadata = {'render.modes': ['human']}
-
     def __init__(self, model, app_args, amc_cfg, services):
         self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
         self.tflogger = distiller.data_loggers.TensorBoardLogger(msglogger.logdir)
@@ -573,23 +586,26 @@ class DistillerWrapperEnvironment(gym.Env):
     def step(self, pruning_action):
         """Take a step, given an action.
 
-        The action represents the desired sparsity.
+        The action represents the desired sparsity for the "current" layer.
         This function is invoked by the Agent.
         """
         msglogger.info("env.step - current_layer_id={}  episode={}".format(self.current_layer_id, self.episode))
+        pruning_action = pruning_action[0]
         msglogger.info("\tAgent pruning_action={}".format(pruning_action))
+        self.agent_action_history.append(pruning_action)
 
         if is_using_continuous_action_space(self.amc_cfg.agent_algo):
-            pruning_action = np.clip(pruning_action[0], self.action_low, self.action_high)
+            pruning_action = np.clip(pruning_action, self.action_low, self.action_high)
         else:
             # Divide the action space into 10 discrete levels (0%, 10%, 20%,....90% sparsity)
             pruning_action = pruning_action / 10
         msglogger.info("\tAgent clipped pruning_action={}".format(pruning_action))
-        self.agent_action_history.append(pruning_action)
+
         if self.amc_cfg.action_constrain_fn is not None:
             pruning_action = self.amc_cfg.action_constrain_fn(self, pruning_action=pruning_action)
             msglogger.info("Constrained pruning_action={}".format(pruning_action))
 
+        # Measure the MACs budget before pruning, so the achieved compression can be computed
         total_macs_before, _ = self.net_wrapper.get_model_resources_requirements(self.model)
         layer_macs = self.net_wrapper.get_layer_macs(self.current_layer())
         msglogger.info("\tlayer_macs={:.2f}".format(layer_macs / self.dense_model_macs))
@@ -599,7 +615,7 @@ class DistillerWrapperEnvironment(gym.Env):
         if pruning_action > 0:
             pruning_action = self.net_wrapper.remove_structures(self.current_layer_id,
                                                                 fraction_to_prune=pruning_action,
-                                                                prune_what="filters")
+                                                                prune_what=self.amc_cfg.pruning_pattern)
         else:
             pruning_action = 0