diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 2badfdc6a8b7288ecf4b8bd6a381bbe077c3dfcf..58891e8d4b8f3f4a84173a89041e49c20d755012 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -347,9 +347,10 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
                               self.params_max_q_val, inplace=True)
 
         self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None
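+        # Scale factors derived from loaded activation stats may be CPU tensors;
+        # keep them on the same device as the weight scale before combining them.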
+        device = self.w_scale.device
 
         if self.preset_act_stats:
-            self.register_buffer('accum_scale', self.in_0_scale * self.w_scale)
+            self.register_buffer('accum_scale', self.in_0_scale.to(device) * self.w_scale)
             if self.per_channel_wts:
                 self.accum_scale = self.accum_scale.squeeze(dim=-1)
         else:
@@ -570,6 +571,34 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
         return output_scale / self.accum_scale, output_zero_point
 
 
+class FP16Wrapper(nn.Module):
+    """
+    A wrapper that runs a module in half precision (fp16), converting inputs to
+    fp16 and outputs back to fp32 as needed.
+
+    Args:
+        module (nn.Module): The module to be replaced.
+        convert_input (:obj:`bool`, optional): If True, inputs are converted to fp16
+            before the wrapped module's forward pass. Default: True.
+        return_fp32 (:obj:`bool`, optional): If True, the output is converted back
+            to fp32 after the forward pass. Default: True.
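+
+    Example (illustrative sketch; fp16 arithmetic is normally run on a CUDA device):
+        >>> fp16_layer = FP16Wrapper(nn.Linear(10, 5).cuda())
+        >>> out = fp16_layer(torch.randn(2, 10).cuda())  # fp32 in/out, fp16 compute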
+    """
+    def __init__(self, module: nn.Module, convert_input=True, return_fp32=True):
+        super(FP16Wrapper, self).__init__()
+        self.wrapped_module = module.half()
+        self.return_fp32 = return_fp32
+        self.convert_input_fp16 = convert_input
+
+    def forward(self, *input):
+        if self.convert_input_fp16:
+            input = distiller.convert_tensors_recursively_to(input, dtype=torch.float16)
+
+        result = self.wrapped_module(*input)
+        if self.return_fp32:
+            return distiller.convert_tensors_recursively_to(result, dtype=torch.float32)
+
+        return result
+
+
 class RangeLinearEmbeddingWrapper(nn.Module):
     def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None):
         if not isinstance(wrapped_module, nn.Embedding):
@@ -585,8 +614,10 @@ class RangeLinearEmbeddingWrapper(nn.Module):
         else:
             w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode)
 
-        self.register_buffer('w_scale', w_scale)
-        self.register_buffer('w_zero_point', w_zero_point)
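+        # Quantization params computed from stats may be created on the CPU;
+        # move them to the weight's device before quantizing the weight in-place.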
+        device = wrapped_module.weight.device
+
+        self.register_buffer('w_scale', w_scale.to(device))
+        self.register_buffer('w_zero_point', w_zero_point.to(device))
         linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, self.min_q_val,
                               self.max_q_val, inplace=True)
 
@@ -607,6 +638,7 @@ class PostTrainLinearQuantizer(Quantizer):
     Args:
         model (torch.nn.Module): Model to be quantized
         bits_activations/parameters/accum (int): Number of bits to be used when quantizing each tensor type
+        overrides (:obj:`OrderedDict`, optional): Overrides the default quantization settings for specific layers.
         clip_acts (bool): See RangeLinearQuantWrapper
         no_clip_layers (list): List of fully-qualified layer names for which activations clipping should not be done.
             A common practice is to not clip the activations of the last layer before softmax.
@@ -614,10 +646,14 @@ class PostTrainLinearQuantizer(Quantizer):
         per_channel_wts (bool): Set to True to enable per-channel quantization of weights (per output channel)
         model_activation_stats (str / dict / OrderedDict): Either a path to activation stats YAML file, or a dictionary
             containing the stats. If None then stats will be calculated dynamically.
+        fp16 (bool): Set to True to convert modules to half precision.
+    Note:
+        If fp16 is set to True, all layers (except those overridden in `overrides`) will be
+        converted to half precision, regardless of bits_activations/parameters/accum.
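+
+    Example (illustrative sketch; assumes a trained fp32 `model` and the `prepare_model()`
+        call provided by Distiller's Quantizer base class):
+        >>> quantizer = PostTrainLinearQuantizer(model, fp16=True)
+        >>> quantizer.prepare_model()  # supported layers are wrapped with FP16Wrapper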
     """
     def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32,
                  overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=False, no_clip_layers=None,
-                 per_channel_wts=False, model_activation_stats=None):
+                 per_channel_wts=False, model_activation_stats=None, fp16=False):
         super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations,
                                                        bits_weights=bits_parameters, bits_bias=bits_accum,
                                                        overrides=overrides, train_with_fp_copy=False)
@@ -641,29 +677,37 @@ class PostTrainLinearQuantizer(Quantizer):
                                                     'mode': str(mode).split('.')[1], 'clip_acts': clip_acts,
                                                     'no_clip_layers': no_clip_layers,
                                                     'per_channel_wts': per_channel_wts}}
-        
-        def replace_param_layer(module, name, qbits_map):
+
+        def replace_param_layer(module, name, qbits_map,
+                                per_channel_wts=per_channel_wts,
+                                mode=mode,
+                                fp16=fp16):
+            if fp16:
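+                # fp16 conversion takes precedence over the integer quantization settings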
+                return FP16Wrapper(module)
             norm_name = distiller.utils.normalize_module_name(name)
-            clip = clip_acts and norm_name not in self.no_clip_layers
+            clip = self.clip_acts and norm_name not in self.no_clip_layers
             return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts,
                                                      num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip,
                                                      per_channel_wts=per_channel_wts,
                                                      activation_stats=self.model_activation_stats.get(norm_name, None))
 
-        def replace_non_param_layer(wrapper_type, module, name, qbits_map):
+        def replace_non_param_layer(wrapper_type, module, name, qbits_map, fp16=fp16):
+            if fp16:
+                return FP16Wrapper(module)
             norm_name = distiller.utils.normalize_module_name(name)
             clip = self.clip_acts and norm_name not in self.no_clip_layers
             return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip,
                                 activation_stats=self.model_activation_stats.get(norm_name, None))
 
-        def replace_embedding(module, name, qbits_map):
+        def replace_embedding(module, name, qbits_map, fp16=fp16):
+            if fp16:
+                return FP16Wrapper(module, convert_input=False)
             norm_name = distiller.utils.normalize_module_name(name)
             return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode,
                                                stats=self.model_activation_stats.get(norm_name, None))
 
-
         self.clip_acts = clip_acts
-        self.no_clip_layers = [] or no_clip_layers
+        self.no_clip_layers = no_clip_layers or []
         self.model_activation_stats = model_activation_stats or {}
         self.bits_accum = bits_accum
         self.mode = mode