Unverified commit e82d9380, authored by Guy Jacob, committed by GitHub

Post-train quant: Refactor inputs quantization (#454)

* Fake quant wrapper now also works on (fake) quantized inputs
* Remove 'requires_quantized_inputs' flag
* Unrelated: Moved LinearQuantMode enum to q_utils
parent 47175961
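The first bullet above means a wrapped module can now be fed activations that have already been through a quantize-dequantize ("fake quant") step, instead of requiring raw float inputs and tracking that via the removed requires_quantized_inputs flag. Below is a minimal, self-contained sketch of what fake quantization of an input tensor does; it is plain PyTorch written for illustration, not Distiller's actual wrapper code, and the helper name and range handling are assumptions.

import torch

def fake_quantize(x, num_bits=8):
    # Asymmetric linear quantize-dequantize: round to the integer grid defined
    # by the tensor's min/max, then map straight back to float. Downstream
    # modules keep working in float but see the quantization error.
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x.max() - x.min()).clamp(min=1e-8) / (qmax - qmin)
    zero_point = qmin - torch.round(x.min() / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = torch.randn(4, 8)
x_fq = fake_quantize(x)      # float tensor carrying quantization error
x_fq2 = fake_quantize(x_fq)  # feeding an already fake-quantized tensor is the
                             # case the refactored wrapper now handles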
@@ -16,9 +16,11 @@
 from .quantizer import Quantizer
 from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, PostTrainLinearQuantizer, \
-    LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \
-    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode
+    QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \
+    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode, \
+    RangeLinearEmbeddingWrapper, RangeLinearFakeQuantWrapper, RangeLinearQuantMatmulWrapper
 from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer
 from .q_utils import *
 del quantizer
 del range_linear
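With the expanded import above, the newer wrapper classes are re-exported at the package level and LinearQuantMode no longer needs to come from range_linear. A short usage sketch; the names are taken from the diff, the surrounding comment is illustrative only.

from distiller.quantization import (LinearQuantMode,            # now re-exported via q_utils
                                    RangeLinearFakeQuantWrapper,
                                    RangeLinearQuantMatmulWrapper,
                                    RangeLinearEmbeddingWrapper)

# Before this commit, LinearQuantMode was defined in and imported from
# distiller.quantization.range_linear; its home is now q_utils, and the
# package-level import above is the path used elsewhere in this commit.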
@@ -18,7 +18,8 @@ Here we implement the greedy search algorithm for automatic quantization.
 """
 import torch
 import torch.nn as nn
-from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode, LinearQuantMode
+from distiller.quantization import LinearQuantMode
+from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode
 from distiller.summary_graph import SummaryGraph
 from distiller.model_transforms import fold_batch_norms
 import distiller.modules
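The import reshuffle in the greedy-search module just reflects the enum's move. For context, a hedged sketch of the kind of call these imports support downstream; the keyword names (bits_activations, bits_parameters, mode, clip_acts) and the prepare_model call are assumed from Distiller's post-training quantization API, not taken from this diff.

import torch
import torch.nn as nn
from distiller.quantization import LinearQuantMode
from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())   # stand-in for a trained model
quantizer = PostTrainLinearQuantizer(model,
                                     bits_activations=8,
                                     bits_parameters=8,
                                     mode=LinearQuantMode.ASYMMETRIC_UNSIGNED,
                                     clip_acts=ClipMode.AVG)
quantizer.prepare_model(torch.randn(1, 3, 224, 224))    # dummy input for graph tracing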
@@ -18,6 +18,12 @@ from enum import Enum
 import torch

+class LinearQuantMode(Enum):
+    SYMMETRIC = 1
+    ASYMMETRIC_UNSIGNED = 2
+    ASYMMETRIC_SIGNED = 3
+
 def _prep_saturation_val_tensor(sat_val):
     is_scalar = not isinstance(sat_val, torch.Tensor)
     out = torch.tensor(sat_val) if is_scalar else sat_val.clone().detach()
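The enum that selects symmetric vs. asymmetric range-linear quantization now lives in q_utils, next to helpers such as _prep_saturation_val_tensor. For context, here is a sketch of the arithmetic those saturation values feed into, i.e. how a (min, max) range and a LinearQuantMode choice become a scale and zero point; this mirrors standard range-linear quantization math and is not necessarily q_utils' exact implementation.

import torch
from distiller.quantization import LinearQuantMode  # the enum added above

def linear_quant_params(sat_min, sat_max, num_bits, mode):
    # Illustrative only: quantization is assumed to be q = round(x * scale - zero_point).
    if mode == LinearQuantMode.SYMMETRIC:
        sat = torch.max(sat_min.abs(), sat_max.abs())
        n = 2 ** (num_bits - 1) - 1                  # e.g. 127 for 8 bits
        scale = n / sat
        zero_point = torch.zeros_like(scale)
    else:                                            # asymmetric, unsigned or signed
        n = 2 ** num_bits - 1                        # e.g. 255 for 8 bits
        scale = n / (sat_max - sat_min)
        zero_point = scale * sat_min
        if mode == LinearQuantMode.ASYMMETRIC_SIGNED:
            zero_point += 2 ** (num_bits - 1)
    return scale, zero_point

scale, zp = linear_quant_params(torch.tensor([-1.5]), torch.tensor([2.0]),
                                num_bits=8, mode=LinearQuantMode.ASYMMETRIC_UNSIGNED)
# scale ~= 72.9, zp ~= -109.3: q = round(x * scale - zp) maps [-1.5, 2.0] onto [0, 255]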
This diff is collapsed.