diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 2badfdc6a8b7288ecf4b8bd6a381bbe077c3dfcf..58891e8d4b8f3f4a84173a89041e49c20d755012 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -347,9 +347,10 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
                               self.params_max_q_val, inplace=True)
 
         self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None
+        device = self.w_scale.device
 
         if self.preset_act_stats:
-            self.register_buffer('accum_scale', self.in_0_scale * self.w_scale)
+            self.register_buffer('accum_scale', self.in_0_scale.to(device) * self.w_scale)
             if self.per_channel_wts:
                 self.accum_scale = self.accum_scale.squeeze(dim=-1)
         else:
@@ -570,6 +571,34 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
         return output_scale / self.accum_scale, output_zero_point
 
 
+class FP16Wrapper(nn.Module):
+    """
+    A wrapper that replaces a module with a half precision version.
+
+    Args:
+        module (nn.Module): The module to be replaced.
+        convert_input (:obj:`bool`, optional): Specifies whether an input conversion
+            to fp16 is required for forward. Default: True.
+        return_fp32 (:obj:`bool`, optional): Specifies whether the output needs
+            to be converted back to fp32. Default: True.
+    """
+    def __init__(self, module: nn.Module, convert_input=True, return_fp32=True):
+        super(FP16Wrapper, self).__init__()
+        self.wrapped_module = module.half()
+        self.return_fp32 = return_fp32
+        self.convert_input_fp16 = convert_input
+
+    def forward(self, *input):
+        if self.convert_input_fp16:
+            input = distiller.convert_tensors_recursively_to(input, dtype=torch.float16)
+
+        result = self.wrapped_module(*input)
+        if self.return_fp32:
+            return distiller.convert_tensors_recursively_to(result, dtype=torch.float32)
+
+        return result
+
+
 class RangeLinearEmbeddingWrapper(nn.Module):
     def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None):
         if not isinstance(wrapped_module, nn.Embedding):
@@ -585,8 +614,10 @@ class RangeLinearEmbeddingWrapper(nn.Module):
         else:
             w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode)
 
-        self.register_buffer('w_scale', w_scale)
-        self.register_buffer('w_zero_point', w_zero_point)
+        device = wrapped_module.weight.device
+
+        self.register_buffer('w_scale', w_scale.to(device))
+        self.register_buffer('w_zero_point', w_zero_point.to(device))
         linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, self.min_q_val,
                               self.max_q_val, inplace=True)
 
@@ -607,6 +638,7 @@ class PostTrainLinearQuantizer(Quantizer):
     Args:
         model (torch.nn.Module): Model to be quantized
         bits_activations/parameters/accum (int): Number of bits to be used when quantizing each tensor type
+        overrides (:obj:`OrderedDict`, optional): Overrides the layers quantization settings.
         clip_acts (bool): See RangeLinearQuantWrapper
         no_clip_layers (list): List of fully-qualified layer names for which activations clipping should not be done.
             A common practice is to not clip the activations of the last layer before softmax.
@@ -614,10 +646,14 @@ class PostTrainLinearQuantizer(Quantizer):
         per_channel_wts (bool): Set to True to enable per-channel quantization of weights (per output channel)
         model_activation_stats (str / dict / OrderedDict): Either a path to activation stats YAML file, or a dictionary
             containing the stats. If None then stats will be calculated dynamically.
+        fp16 (bool): Set to True to convert modules to half precision.
+        Note:
+            If fp16 is set to True, all the layers (except those overridden in `overrides`) will be converted
+            to half precision, regardless of bits_activations/parameters/accum.
     """
     def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32, overrides=None,
                  mode=LinearQuantMode.SYMMETRIC, clip_acts=False, no_clip_layers=None,
-                 per_channel_wts=False, model_activation_stats=None):
+                 per_channel_wts=False, model_activation_stats=None, fp16=False):
         super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations,
                                                        bits_weights=bits_parameters, bits_bias=bits_accum,
                                                        overrides=overrides, train_with_fp_copy=False)
@@ -641,29 +677,37 @@ class PostTrainLinearQuantizer(Quantizer):
                                                    'mode': str(mode).split('.')[1],
                                                    'clip_acts': clip_acts,
                                                    'no_clip_layers': no_clip_layers,
                                                    'per_channel_wts': per_channel_wts}}
-
-        def replace_param_layer(module, name, qbits_map):
+
+        def replace_param_layer(module, name, qbits_map,
+                                per_channel_wts=per_channel_wts,
+                                mode=mode,
+                                fp16=fp16):
+            if fp16:
+                return FP16Wrapper(module)
             norm_name = distiller.utils.normalize_module_name(name)
-            clip = clip_acts and norm_name not in self.no_clip_layers
+            clip = self.clip_acts and norm_name not in self.no_clip_layers
             return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts,
                                                      num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip,
                                                      per_channel_wts=per_channel_wts,
                                                      activation_stats=self.model_activation_stats.get(norm_name, None))
 
-        def replace_non_param_layer(wrapper_type, module, name, qbits_map):
+        def replace_non_param_layer(wrapper_type, module, name, qbits_map, fp16=fp16):
+            if fp16:
+                return FP16Wrapper(module)
             norm_name = distiller.utils.normalize_module_name(name)
             clip = self.clip_acts and norm_name not in self.no_clip_layers
             return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip,
                                 activation_stats=self.model_activation_stats.get(norm_name, None))
 
-        def replace_embedding(module, name, qbits_map):
+        def replace_embedding(module, name, qbits_map, fp16=fp16):
+            if fp16:
+                return FP16Wrapper(module, convert_input=False)
             norm_name = distiller.utils.normalize_module_name(name)
             return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode,
                                                stats=self.model_activation_stats.get(norm_name, None))
 
-        self.clip_acts = clip_acts
-        self.no_clip_layers = [] or no_clip_layers
+        self.no_clip_layers = no_clip_layers or []
         self.model_activation_stats = model_activation_stats or {}
         self.bits_accum = bits_accum
         self.mode = mode
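
For reference, a minimal usage sketch (not part of the patch) of how the FP16Wrapper added above behaves on its own. The nn.Linear layer, tensor shapes and CUDA placement are illustrative assumptions; at the quantizer level the same path is selected per layer by the replace_* functions when fp16=True is passed to PostTrainLinearQuantizer.

    # Illustrative sketch, assuming a CUDA device (half-precision matmuls are not
    # generally supported on CPU in this PyTorch era).
    import torch
    import torch.nn as nn

    from distiller.quantization.range_linear import FP16Wrapper

    fc = nn.Linear(128, 64).cuda()   # fp32 module on GPU
    fp16_fc = FP16Wrapper(fc)        # parameters are cast to fp16 once, in __init__

    x = torch.randn(4, 128).cuda()   # fp32 input
    y = fp16_fc(x)                   # input cast to fp16, forward runs in half precision
    print(y.dtype)                   # torch.float32 - output is cast back by default

With return_fp32=False the wrapper would return the raw fp16 output instead, and with convert_input=False (as used for embeddings in replace_embedding) the input is passed through unconverted.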