diff --git a/distiller/quantization/ptq_coordinate_search.py b/distiller/quantization/ptq_coordinate_search.py
index ea8ecf815bb56d32b3d53b6763a426546c75a676..12036a580c78582c5d4b0fbcbfe8b02d291a75b6 100644
--- a/distiller/quantization/ptq_coordinate_search.py
+++ b/distiller/quantization/ptq_coordinate_search.py
@@ -44,6 +44,13 @@ import argparse
 msglogger = logging.getLogger()
 
 
+def _make_non_parallel_copy(model):
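+    # Strip nn.DataParallel wrappers only when they are actually present; otherwise return the model as-is to avoid an unnecessary copy.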
+    if any(isinstance(m, nn.DataParallel) for m in model.modules()):
+        return distiller.make_non_parallel_copy(model)
+    return model
+
+
 def quant_params_dict2vec(p_dict, search_clipping=False):
     """
     Convert quantization params dictionary returned by post-train quantizer to a numpy array that can be used
@@ -107,6 +113,8 @@ _INIT_MODES = {
 
 
 def _init_mode_from_str(init_mode_str):
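+    # Accept mode names case-insensitively (e.g. when parsed from the command line).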
+    init_mode_str = init_mode_str.upper()
     if init_mode_str not in _INIT_MODES:
         raise ValueError('Unsupported init mode \'%s\'. '
                          'The supported init modes are: %s.' % (init_mode_str, _INIT_MODES))
@@ -168,13 +175,14 @@ def get_input_for_layer(model, layer_name, eval_fn):
 
     handle = layer.register_forward_pre_hook(hook_layer_input)
     eval_fn(model)
-    assert len(layer_inputs) == 1
     handle.remove()
-    return layer_inputs[0]
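+    # The pre-hook may fire once per batch, so concatenate all captured inputs rather than assuming a single forward pass.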
+    return torch.cat(layer_inputs)
 
 
 def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_mode=ClipMode.NONE,
-                                   init_method='Powell', eval_fn=None, search_clipping=False):
+                                   init_method='Powell', eval_fn=None, search_clipping=False,
+                                   run_device='cpu'):
     """
     Initializes a layer's linear quant parameters.
     This is done to set the scipy.optimize.minimize initial guess.
@@ -215,15 +222,16 @@ def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_m
 
     if callable(init_mode):
         input_for_layer = get_input_for_layer(original_model, layer_name, eval_fn)
-        quantized_layer = optimize_for_layer(layer, quantized_layer, init_mode, input_for_layer, init_method,
-                                             search_clipping=search_clipping)
+        quantized_layer = optimize_for_layer(layer.to(device=run_device), quantized_layer.to(device=run_device),
+                                             init_mode, input_for_layer, init_method, search_clipping=search_clipping)
+        del input_for_layer
 
     distiller.model_setattr(quantizer.model, denorm_layer_name, quantized_layer)
     quantizer.model.eval()
 
 
 def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, init_mode,
-                             init_method='Powell', search_clipping=False):
+                             init_method='Powell', search_clipping=False, run_device='cpu'):
     """
     Initializes all linear quantization parameters of the model.
     Args:
@@ -239,8 +247,10 @@ def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, in
         init_method: See `init_layer_linear_qaunt_params`.
         search_clipping (bool): if set, optimize clipping values, otherwise optimize scale factor
     """
-    original_model = distiller.make_non_parallel_copy(original_model)
-    layers_topological_order = SummaryGraph(original_model, dummy_input).layers_topological_order()
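+    # Move the non-parallel copy to run_device only if a callable init mode will need to run it; otherwise keep it on the CPU.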
+    non_parallel_model = _make_non_parallel_copy(original_model).to(
+        device=run_device if callable(init_mode) else 'cpu')
+    layers_topological_order = SummaryGraph(non_parallel_model, dummy_input).layers_topological_order()
     q_named_modules = OrderedDict(quantizer.model.named_modules())
     for module_name in layers_topological_order:
         # check to see if it was quantized:
@@ -249,11 +258,12 @@ def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, in
             continue
         module_init_mode = init_mode[module_name] if isinstance(init_mode, dict) else init_mode
         msglogger.debug('Initializing layer \'%s\' using %s mode' % (module_name, module_init_mode))
-        init_layer_linear_quant_params(quantizer, original_model, module_name, module_init_mode,
+        init_layer_linear_quant_params(quantizer, non_parallel_model, module_name, module_init_mode,
                                        init_method=init_method,
                                        eval_fn=eval_fn,
-                                       search_clipping=search_clipping)
-    del original_model
+                                       search_clipping=search_clipping, run_device=run_device)
+    if non_parallel_model is not original_model:
+        del non_parallel_model
 
     quantizer._post_prepare_model()
     quantizer.model.eval()
@@ -271,7 +281,7 @@ def add_coordinate_search_args(parser: argparse.ArgumentParser):
                        help='Use scipy.optimize.basinhopping stochastic global minimum search.')
     group.add_argument('--lapq-basinhopping-niter', '--lapq-bh-niter', default=100,
                        help='Number of iterations for the basinhopping algorithm.')
-    group.add_argument('--lapq-init-mode', default='NONE', choices=list(_INIT_MODES),
+    group.add_argument('--lapq-init-mode', default='NONE', type=_init_mode_from_str,
                        help='The mode of quant initalization. Choices: ' + '|'.join(list(_INIT_MODES)))
     group.add_argument('--lapq-init-method', default='Powell',
                        help='If --lapq-init-mode was specified as L1/L2/L3, this specifies the method of '
@@ -358,16 +368,20 @@ def ptq_coordinate_search(quantizer, dummy_input, eval_fn, test_fn=None, method=
     if quantizer.prepared:
         raise ValueError('Expecting a quantizer for which prepare_model has not been called')
 
-    original_model = deepcopy(quantizer.model)
+    run_device = distiller.model_device(quantizer.model)
+
+    original_model = deepcopy(quantizer.model).cpu()
     original_model = fold_batch_norms(original_model, dummy_input)
 
     if not quantizer.model_activation_stats:
         msglogger.info('Collecting stats for model...')
-        model_temp = distiller.utils.make_non_parallel_copy(original_model)
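+        # Collect stats on run_device, reusing the model directly when it has no nn.DataParallel wrappers.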
+        model_temp = _make_non_parallel_copy(original_model).to(device=run_device)
         act_stats = collect_quant_stats(model_temp, eval_fn,
                                         inplace_runtime_check=True, disable_inplace_attrs=True,
                                         save_dir=getattr(msglogger, 'logdir', '.'))
-        del model_temp
+        if model_temp is not original_model:
+            del model_temp
         quantizer.model_activation_stats = act_stats
         quantizer.model.quantizer_metadata['params']['model_activation_stats'] = act_stats
 
@@ -385,12 +398,13 @@ def ptq_coordinate_search(quantizer, dummy_input, eval_fn, test_fn=None, method=
 
     quantizer.prepare_model(dummy_input)
     quantizer.model.eval()
+    quantizer.model = quantizer.model.cpu()
 
     validate_quantization_settings(quantizer.model, search_clipping)
 
     msglogger.info("Initializing quantization parameters...")
     init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, init_mode, init_method,
-                             search_clipping=search_clipping)
+                             search_clipping=search_clipping, run_device=run_device)
 
     msglogger.info("Evaluating initial quantization score...")
     best_data = {
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index c3bd293e17bfc6757638a50e1f54314b9aedc51d..19a8e83e36fd4e9eaa47544d1578b8073475f863 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -232,6 +232,7 @@ class Quantizer(object):
         if dummy_input is not None:
             summary_graph = distiller.SummaryGraph(self.model, dummy_input)
             self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)
+            del summary_graph
 
         model_device = distiller.model_device(self.model)
 
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index e02f5ce4656f2ebdb1f9537dacdf7ed4e1d9afaf..71a22048140663b208cda1f0684c0de605ddae3b 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -2023,6 +2023,7 @@ class PostTrainLinearQuantizer(Quantizer):
         # After BN folding model need to re-generate the adjacency map
         summary_graph = distiller.SummaryGraph(self.model, dummy_input)
         self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)
+        del summary_graph
 
         if not self.model_activation_stats:
             return
diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index bad898cb6ecd9e7bfb1f0851d834405d46b16bf5..4e270c2fb2d52da6ad16984c73fbdb0cbc0cc416 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -195,6 +195,9 @@ class SummaryGraph(object):
         self.add_macs_attr()
         self.add_footprint_attr()
         self.add_arithmetic_intensity_attr()
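+        # Drop the trace and graph references explicitly so they can be garbage collected.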
+        del trace
+        del graph
         del model_clone
 
     def __merge_pad_avgpool(self):