diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 87f85466a54ef1f67b1f435edaf8c62e57d3384b..6e3fdb18518414d00025f72186b0a58070dd0ed3 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -242,9 +242,11 @@ def add_post_train_quant_args(argparser):
     group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
                        help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
     group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
-                       help='Number of bits for quantization of activations')
+                       help='Number of bits for quantization of activations. Use 0 to not quantize activations. '
+                            'Default value is 8')
     group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS',
-                       help='Number of bits for quantization of weights')
+                       help='Number of bits for quantization of weights. Use 0 to not quantize weights. '
+                            'Default value is 8')
     group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS',
                        help='Number of bits for quantization of the accumulator')
     group.add_argument('--qe-clip-acts', '--qeca', type=clip_mode_str, default='none',
@@ -312,6 +314,15 @@ class RangeLinearQuantWrapper(nn.Module):
         self.output_quant_settings = QuantSettings(num_bits_acts, mode, clip_acts, clip_n_stds, clip_half_range, False)
         self.accum_quant_settings = QuantSettings(num_bits_accum, LinearQuantMode.SYMMETRIC,
                                                   ClipMode.NONE, None, False, False)
+
+        self.preset_act_stats = False
+        self.register_buffer('num_forwards', torch.zeros(1, dtype=torch.long))
+
+        # Activations not quantized - stop here
+        if num_bits_acts is None:
+            return
+
+        # Activations are quantized - setup quantization parameters
         if self.requires_quantized_inputs:
             self.inputs_quant_settings_overrides = OrderedDict()
             for k, v in input_overrides.items():
@@ -320,13 +331,12 @@ class RangeLinearQuantWrapper(nn.Module):
                     quant_settings = deepcopy(self.output_quant_settings)
                     quant_settings.clip_half_range = False
                 else:
-                    quant_settings = QuantSettings(v.pop('bits_activations', self.output_quant_settings.num_bits),
-                                                   verify_quant_mode(
-                                                       v.pop('mode', self.output_quant_settings.quant_mode)),
-                                                   verify_clip_mode(
-                                                       v.pop('clip_acts', self.output_quant_settings.clip_mode)),
-                                                   v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds),
-                                                   False, False)
+                    quant_settings = QuantSettings(
+                        v.pop('bits_activations', self.output_quant_settings.num_bits),
+                        verify_quant_mode(v.pop('mode', self.output_quant_settings.quant_mode)),
+                        verify_clip_mode(v.pop('clip_acts', self.output_quant_settings.clip_mode)),
+                        v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds),
+                        False, False)
                 if v:  # Poor man's input checking on input overrides dict
                     raise ValueError('Input overrides dict contains unsupported keys:', list(v.keys()))
@@ -369,10 +379,8 @@ class RangeLinearQuantWrapper(nn.Module):
         else:
             self.preset_act_stats = False
 
-        self.register_buffer('num_forwards', torch.zeros(1, dtype=torch.long))
-
     def named_acts_quant_params(self):
-        if self.preset_act_stats:
+        if self.output_quant_settings.num_bits is not None and self.preset_act_stats:
             # Output scale buffers are saved in the model only when stats are used
             yield 'output_scale', self.output_scale
             yield 'output_zero_point', self.output_zero_point
@@ -381,6 +389,13 @@ class RangeLinearQuantWrapper(nn.Module):
         if self.training:
             raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode")
 
+        if self.output_quant_settings.num_bits is None:
+            # Pass through
+            out = self.wrapped_module(*inputs)
+            if self.clip_half_range:
+                out = f.relu(out)
+            return out
+
         device = inputs[0].device
         for buffer_name, buffer in self._buffers.items():
             setattr(self, buffer_name, buffer.to(device))
@@ -498,6 +513,9 @@ class RangeLinearQuantWrapper(nn.Module):
         raise NotImplementedError
 
     def extra_repr(self):
+        if self.output_quant_settings.num_bits is None:
+            return 'output_quant_settings=Not_Quantized'
+
         tmpstr = 'output_quant_settings={0}'.format(self.output_quant_settings)
         tmpstr += '\naccum_quant_settings={0}'.format(self.accum_quant_settings)
         tmpstr += '\nrequires_quantized_inputs={0}'.format(self.requires_quantized_inputs)
@@ -574,6 +592,9 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         if not isinstance(wrapped_module, (nn.Conv2d, nn.Conv3d, nn.Linear)):
             raise ValueError(self.__class__.__name__ + ' can wrap only Conv2D, Conv3D and Linear modules')
 
+        # If activations are not quantized, we do fake quantization of the parameters, that is - quant and de-quant
+        self.fake_quant_params = self.output_quant_settings.num_bits is None
+
         self.wts_quant_settings = QuantSettings(num_bits_params, mode, ClipMode.NONE, None, False, per_channel_wts)
 
         self.params_min_q_val, self.params_max_q_val = get_quantized_range(
@@ -591,20 +612,9 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point,
                               self.params_min_q_val, self.params_max_q_val, inplace=True)
 
-        self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None
-        device = self.w_scale.device
-
-        if self.preset_act_stats:
-            t = torch.zeros_like(self.w_scale)
-            if self.wts_quant_settings.per_channel:
-                t = t.squeeze(dim=-1)
-            self.register_buffer('accum_scale', t)
-        else:
-            self.accum_scale = torch.ones(1).to(device)
-
         # Quantize bias
         self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None
-        if self.has_bias and not self.preset_act_stats:
+        if self.has_bias and (self.fake_quant_params or not self.preset_act_stats):
             b_scale, b_zero_point = _get_quant_params_from_tensor(wrapped_module.bias,
                                                                   self.wts_quant_settings.num_bits,
                                                                   self.wts_quant_settings.quant_mode)
@@ -612,8 +622,27 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
             self.register_buffer('b_zero_point', b_zero_point)
             base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, self.b_zero_point,
                                              self.params_min_q_val, self.params_max_q_val)
-            # Dynamic ranges - save in auxiliary buffer, requantize each time based on dynamic input scale factor
-            self.register_buffer('base_b_q', base_b_q)
+            if not self.preset_act_stats:
+                # Dynamic ranges - save in auxiliary buffer,
+                # requantize each time based on dynamic input scale factor
+                self.register_buffer('base_b_q', base_b_q)
+
+        # Activations not quantized - de-quant parameters and return
+        if self.fake_quant_params:
+            linear_dequantize(wrapped_module.weight.data, self.w_scale, self.w_zero_point, inplace=True)
+            if self.has_bias:
+                wrapped_module.bias = torch.nn.Parameter(linear_dequantize(base_b_q, self.b_scale, self.b_zero_point))
+            return
+
+        # Activations are quantized - setup accumulator quantization parameters
+        device = self.w_scale.device
+        if self.preset_act_stats:
+            t = torch.zeros_like(self.w_scale)
+            if self.wts_quant_settings.per_channel:
+                t = t.squeeze(dim=-1)
+            self.register_buffer('accum_scale', t)
+        else:
+            self.accum_scale = torch.ones(1).to(device)
 
         # A flag indicating that the simulated quantized weights are pre-shifted. for faster performance.
         # In the first forward pass - `w_zero_point` is added into the weights, to allow faster inference,
@@ -623,7 +652,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         self.register_buffer('is_simulated_quant_weight_shifted', torch.tensor(0, dtype=torch.uint8, device=device))
 
     def state_dict(self, destination=None, prefix='', keep_vars=False):
-        if self.is_simulated_quant_weight_shifted:
+        if not self.fake_quant_params and self.is_simulated_quant_weight_shifted:
             # We want to return the weights to their integer representation:
             self.wrapped_module.weight.data -= self.w_zero_point
             self.is_simulated_quant_weight_shifted.fill_(False)  # i.e. is_simulated_quant_weight_shifted = False
@@ -920,6 +949,11 @@ class FPWrapper(nn.Module):
 
         return result
 
+    def extra_repr(self):
+        tmpstr = 'float_dtype={}, convert_input={}, return_fp32={}'.format(self.dtype, self.convert_input,
+                                                                           self.return_fp32)
+        return tmpstr
+
 
 class FP16Wrapper(FPWrapper):
     def __init__(self, module, convert_input=True, return_fp32=True):
@@ -993,6 +1027,12 @@ class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
     def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
         return output_scale, output_zero_point
 
+    def extra_repr(self):
+        tmpstr = super(RangeLinearFakeQuantWrapper, self).extra_repr()
+        if self.dtype:
+            tmpstr += '\nwrapped_module_float_dtype={}.'.format(self.dtype)
+        return tmpstr
+
 
 _ptq_wrappers_int_only = (RangeLinearQuantWrapper, RangeLinearEmbeddingWrapper)
 _ptq_wrappers_all = _ptq_wrappers_int_only + (FPWrapper,)
@@ -1104,31 +1144,40 @@ class PostTrainLinearQuantizer(Quantizer):
                                                    'model_activation_stats': model_activation_stats,
                                                    'overrides': overrides_bkp}}
 
-        def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts,
-                                mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits,
-                                clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                                input_overrides=None, fpq_module=fpq_module, fake=False):
+        def _check_fp16_arg(fp16, fpq_module):
             if fp16:
                 warnings.warn("Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.",
                               DeprecationWarning)
                 fpq_module = fpq_module or 16
+            return fpq_module
+
+        def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts,
+                                mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits,
+                                clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                input_overrides=None, fpq_module=fpq_module, fake=False):
+            fpq_module = _check_fp16_arg(fp16, fpq_module)
+            if fpq_module and not fake:
+                return FPWrapper(module, fpq_module)
+
             norm_name = distiller.utils.normalize_module_name(name)
+            activation_stats = self.model_activation_stats.get(norm_name, None)
             clip_acts = verify_clip_mode(clip_acts)
-            if fpq_module:
-                if not fake:
-                    return FPWrapper(module, fpq_module)
-                else:
-                    return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
-                                                       activation_stats=self.model_activation_stats.get(norm_name,
-                                                                                                         None),
-                                                       clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                                                       scale_approx_mult_bits=scale_approx_mult_bits,
-                                                       fpq_module=fpq_module)
-
-            return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts,
+
+            qbits = qbits_map[name]
+            if qbits.acts is not None and qbits.wts is None:
+                # Quantizing only activations equals fake-quantization
+                fake = True
+
+            if fake:
+                return RangeLinearFakeQuantWrapper(module, qbits.acts, mode=mode, clip_acts=clip_acts,
+                                                   activation_stats=activation_stats,
+                                                   clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                                   scale_approx_mult_bits=scale_approx_mult_bits,
+                                                   fpq_module=fpq_module)
+
+            return RangeLinearQuantParamLayerWrapper(module, qbits.acts, qbits.wts,
                                                      num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts,
                                                      per_channel_wts=per_channel_wts,
-                                                     activation_stats=self.model_activation_stats.get(norm_name, None),
+                                                     activation_stats=activation_stats,
                                                      clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                      scale_approx_mult_bits=scale_approx_mult_bits,
                                                      input_overrides=input_overrides,
@@ -1139,36 +1188,37 @@ class PostTrainLinearQuantizer(Quantizer):
                                     clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                     input_overrides=None, inputs_quant_auto_fallback=inputs_quant_auto_fallback,
                                     fpq_module=fpq_module, fake=False):
+            fpq_module = _check_fp16_arg(fp16, fpq_module)
+            if fpq_module and not fake:
+                return FPWrapper(module, fpq_module)
+
             norm_name = distiller.utils.normalize_module_name(name)
+            activation_stats = self.model_activation_stats.get(norm_name, None)
             clip_acts = verify_clip_mode(clip_acts)
-            if fp16:
-                warnings.warn("Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.",
-                              DeprecationWarning)
-                fpq_module = fpq_module or 16
-            if fpq_module:
-                if fake:
-                    return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
-                                                       activation_stats=self.model_activation_stats.get(norm_name, None),
-                                                       clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                                                       scale_approx_mult_bits=scale_approx_mult_bits,
-                                                       fpq_module=fpq_module)
-                else:
-                    return FPWrapper(module, fpq_module)
+            qbits = qbits_map[name]
+
+            if fake:
+                return RangeLinearFakeQuantWrapper(module, qbits.acts, mode=mode, clip_acts=clip_acts,
+                                                   activation_stats=activation_stats,
+                                                   clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                                   scale_approx_mult_bits=scale_approx_mult_bits,
+                                                   fpq_module=fpq_module)
             try:
-                return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
-                                    activation_stats=self.model_activation_stats.get(norm_name, None),
-                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                return wrapper_type(module, qbits.acts, mode=mode, clip_acts=clip_acts,
+                                    activation_stats=activation_stats,
+                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                     scale_approx_mult_bits=scale_approx_mult_bits, input_overrides=input_overrides,
                                     inputs_quant_auto_fallback=inputs_quant_auto_fallback)
             except NoStatsError:
-                msglogger.warning('WARNING: {0} - quantization of {1} without stats not supported. '
-                                  'Keeping the original FP32 module'.format(name, module.__class__.__name__))
+                warnings.warn('WARNING: {0} - quantization of {1} without stats not supported. '
+                              'Keeping the original FP32 module'.format(name, module.__class__.__name__), UserWarning)
                 return module
 
-        def replace_embedding(module, name, qbits_map, fp16=fp16):
-            if fp16:
-                return FP16Wrapper(module, convert_input=False)
+        def replace_embedding(module, name, qbits_map, fp16=fp16, fpq_module=fpq_module):
+            fpq_module = _check_fp16_arg(fp16, fpq_module)
+            if fpq_module:
+                return FPWrapper(module, fpq_module, convert_input=False)
             norm_name = distiller.utils.normalize_module_name(name)
             return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode,
                                                stats=self.model_activation_stats.get(norm_name, None))
 
@@ -1182,16 +1232,16 @@ class PostTrainLinearQuantizer(Quantizer):
                 pred = self.adjacency_map[name].predecessors[0].name
                 if isinstance(named_modules[pred], RangeLinearQuantWrapper):
                     return nn.Identity()
-            norm_name = distiller.utils.normalize_module_name(name)
-            clip_acts = verify_clip_mode(clip_acts)
+
             if distiller.has_children(module):
                 return module
-            if fp16:
-                warnings.warn("Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.",
-                              DeprecationWarning)
-                fpq_module = 16
+
+            fpq_module = _check_fp16_arg(fp16, fpq_module)
             if not fake:
                 return FPWrapper(module, fpq_module)
+
+            norm_name = distiller.utils.normalize_module_name(name)
+            clip_acts = verify_clip_mode(clip_acts)
             return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
                                                activation_stats=self.model_activation_stats.get(norm_name, None),
                                                clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
@@ -1271,6 +1321,10 @@ class PostTrainLinearQuantizer(Quantizer):
             return distiller.config_component_from_file_by_class(model, args.qe_config_file,
                                                                   'PostTrainLinearQuantizer')
         else:
+            if args.qe_bits_acts == 0:
+                args.qe_bits_acts = None
+            if args.qe_bits_wts == 0:
+                args.qe_bits_wts = None
             overrides = OrderedDict(
                 [(layer, OrderedDict([('clip_acts', 'NONE')]))
                  for layer in args.qe_no_clip_layers]
@@ -1496,8 +1550,6 @@ class PostTrainLinearQuantizer(Quantizer):
                 buffer.data = buffer.data.to(device)
 
 
-
-
 ###############################################################################
 # Quantization-aware training
 ###############################################################################
diff --git a/examples/quantization/post_train_quant/command_line.md b/examples/quantization/post_train_quant/command_line.md
index 061d78f0aa1cddd8b99bb374370dcc124f4f4946..f2eaa4fb37519291c5fb764feb4874abe3d2e017 100644
--- a/examples/quantization/post_train_quant/command_line.md
+++ b/examples/quantization/post_train_quant/command_line.md
@@ -16,8 +16,8 @@ Post-training quantization can either be configured straight from the command-li
 |--------------------------|-----------|---------------------------------------------------------------------------------------|---------|
 | `--quantize-eval`        | `--qe`    | Apply linear quantization to model before evaluation                                   | Off     |
 | `--qe-mode`              | `--qem`   | Linear quantization mode. Choices: "sym", "asym_u", "asym_s"                            | "sym"   |
-| `--qe-bits-acts`         | `--qeba`  | # of bits for quantization of activations                                               | 8       |
-| `--qe-bits-wts`          | `--qebw`  | # of bits for quantization of weights                                                   | 8       |
+| `--qe-bits-acts`         | `--qeba`  | # of bits for quantization of activations. Use 0 to not quantize activations            | 8       |
+| `--qe-bits-wts`          | `--qebw`  | # of bits for quantization of weights. Use 0 to not quantize weights                    | 8       |
 | `--qe-bits-accum`        | N/A       | # of bits for quantization of the accumulator                                           | 32      |
 | `--qe-clip-acts`         | `--qeca`  | Set activations clipping mode. Choices: "none", "avg", "n_std"                          | "none"  |
Choices: "none", "avg", "n_std" | "none" | | `--qe-clip-n-stds` | N/A | When qe-clip-acts is set to 'n_std', this is the number of standard deviations to use | None | diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index a30524f53604889abc34fa7a9be793ce81b60fa8..d54ea330a14c8b898e83819d663074fcee0dcdd4 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -743,3 +743,64 @@ def test_acts_quant_params_rnn(rnn_model): quantizer.update_acts_quant_params(new_config) assert model.rnn.rnn.cells[0].act_o.output_scale == 4 assert model.embedding.w_scale == 59.0 + + +############################################################################### +# Test wrappers with weights-only quantization +############################################################################### +@pytest.fixture(params=[False, True], ids=['perch_off', 'perch_on']) +def per_channel(request): + return request.param + + +@pytest.fixture(params=[False, True], ids=['no_bias', 'with_bias']) +def bias(request): + return request.param + + +def _fake_quant_tensor(tensor, n_bits, mode, per_channel): + q_min, q_max = q_utils.get_quantized_range(n_bits, mode != LinearQuantMode.ASYMMETRIC_UNSIGNED) + scale, zp = _get_quant_params_from_tensor(tensor, n_bits, mode, per_channel=per_channel) + q_utils.linear_quantize_clamp(tensor, scale, zp, q_min, q_max, inplace=True) + q_utils.linear_dequantize(tensor, scale, zp, inplace=True) + + +def _test_wts_only_quant(layer, x, per_channel, bias, num_bits): + layer.weight.data = torch.rand_like(layer.weight) + if bias: + layer.bias.data = torch.rand_like(layer.bias) + mode = LinearQuantMode.ASYMMETRIC_UNSIGNED + + layer_ptq = RangeLinearQuantParamLayerWrapper(deepcopy(layer), None, num_bits, mode=mode, per_channel_wts=per_channel) + layer_ptq.eval() + + layer_manual_q = deepcopy(layer) + _fake_quant_tensor(layer_manual_q.weight.data, num_bits, mode, per_channel) + assert torch.equal(layer_ptq.wrapped_module.weight, layer_manual_q.weight) + if bias: + _fake_quant_tensor(layer_manual_q.bias.data, num_bits, mode, False) + assert torch.equal(layer_ptq.wrapped_module.bias, layer_manual_q.bias) + + y_ptq = layer_ptq(x) + y_manual_q = layer_manual_q(x) + + assert torch.equal(y_ptq, y_manual_q) + + +def test_conv_layer_wrapper_params_only(per_channel, bias): + distiller.set_deterministic() + in_ch = 3 + layer = torch.nn.Conv2d(in_ch, 10, 3, bias=bias) + x = torch.rand(5, in_ch, 5, 5) + + _test_wts_only_quant(layer, x, per_channel, bias, 8) + + +def test_linear_layer_wrapper_params_only(per_channel, bias): + distiller.set_deterministic() + in_features = 50 + layer = torch.nn.Linear(in_features, 30, bias=bias) + + x = torch.rand(5, in_features) + + _test_wts_only_quant(layer, x, per_channel, bias)