Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
distiller
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
llvm
distiller
Commits
9bc69fef
Unverified
Commit
9bc69fef
authored
5 years ago
by
Guy Jacob
Committed by
GitHub
5 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Faster and more memory-efficient impl. of Learned-clipped quant (#336)
parent
cb02798a
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
distiller/quantization/clipped_linear.py
+12
-31
12 additions, 31 deletions
distiller/quantization/clipped_linear.py
with
12 additions
and
31 deletions
distiller/quantization/clipped_linear.py
+
12
−
31
View file @
9bc69fef
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
.quantizer
import
Quantizer
from
.quantizer
import
Quantizer
from
.q_utils
import
*
from
.q_utils
import
*
...
@@ -27,34 +28,6 @@ msglogger = logging.getLogger()
...
@@ -27,34 +28,6 @@ msglogger = logging.getLogger()
###
###
class
LearnedClippedLinearQuantizeSTE
(
torch
.
autograd
.
Function
):
@staticmethod
def
forward
(
ctx
,
input
,
clip_val
,
num_bits
,
dequantize
,
inplace
):
ctx
.
save_for_backward
(
input
,
clip_val
)
if
inplace
:
ctx
.
mark_dirty
(
input
)
scale
,
zero_point
=
asymmetric_linear_quantization_params
(
num_bits
,
0
,
clip_val
.
data
[
0
],
signed
=
False
)
output
=
clamp
(
input
,
0
,
clip_val
.
data
[
0
],
inplace
)
output
=
linear_quantize
(
output
,
scale
,
zero_point
,
inplace
)
if
dequantize
:
output
=
linear_dequantize
(
output
,
scale
,
zero_point
,
inplace
)
return
output
@staticmethod
def
backward
(
ctx
,
grad_output
):
input
,
clip_val
=
ctx
.
saved_tensors
grad_input
=
grad_output
.
clone
()
grad_input
[
input
.
le
(
0
)]
=
0
grad_input
[
input
.
ge
(
clip_val
.
data
[
0
])]
=
0
grad_alpha
=
grad_output
.
clone
()
grad_alpha
[
input
.
lt
(
clip_val
.
data
[
0
])]
=
0
grad_alpha
=
grad_alpha
.
sum
().
expand_as
(
clip_val
)
# Straight-through estimator for the scale factor calculation
return
grad_input
,
grad_alpha
,
None
,
None
,
None
class
ClippedLinearQuantization
(
nn
.
Module
):
class
ClippedLinearQuantization
(
nn
.
Module
):
def
__init__
(
self
,
num_bits
,
clip_val
,
dequantize
=
True
,
inplace
=
False
):
def
__init__
(
self
,
num_bits
,
clip_val
,
dequantize
=
True
,
inplace
=
False
):
super
(
ClippedLinearQuantization
,
self
).
__init__
()
super
(
ClippedLinearQuantization
,
self
).
__init__
()
...
@@ -84,13 +57,18 @@ class LearnedClippedLinearQuantization(nn.Module):
...
@@ -84,13 +57,18 @@ class LearnedClippedLinearQuantization(nn.Module):
self
.
inplace
=
inplace
self
.
inplace
=
inplace
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
input
=
LearnedClippedLinearQuantizeSTE
.
apply
(
input
,
self
.
clip_val
,
self
.
num_bits
,
# Clip between 0 to the learned clip_val
self
.
dequantize
,
self
.
inplace
)
input
=
F
.
relu
(
input
,
self
.
inplace
)
# Using the 'where' operation as follows gives us the correct gradient with respect to clip_val
input
=
torch
.
where
(
input
<
self
.
clip_val
,
input
,
self
.
clip_val
)
with
torch
.
no_grad
():
scale
,
zero_point
=
asymmetric_linear_quantization_params
(
self
.
num_bits
,
0
,
self
.
clip_val
,
signed
=
False
)
input
=
LinearQuantizeSTE
.
apply
(
input
,
scale
,
zero_point
,
self
.
dequantize
,
self
.
inplace
)
return
input
return
input
def
__repr__
(
self
):
def
__repr__
(
self
):
inplace_str
=
'
, inplace
'
if
self
.
inplace
else
''
inplace_str
=
'
, inplace
'
if
self
.
inplace
else
''
return
'
{0}(num_bits={1}, clip_val={2}{3})
'
.
format
(
self
.
__class__
.
__name__
,
self
.
num_bits
,
self
.
clip_val
,
return
'
{0}(num_bits={1}, clip_val={2}{3})
'
.
format
(
self
.
__class__
.
__name__
,
self
.
num_bits
,
self
.
clip_val
.
item
()
,
inplace_str
)
inplace_str
)
...
@@ -126,6 +104,7 @@ class WRPNQuantizer(Quantizer):
...
@@ -126,6 +104,7 @@ class WRPNQuantizer(Quantizer):
self
.
replacement_factory
[
nn
.
ReLU
]
=
relu_replace_fn
self
.
replacement_factory
[
nn
.
ReLU
]
=
relu_replace_fn
def
dorefa_quantize_param
(
param_fp
,
param_meta
):
def
dorefa_quantize_param
(
param_fp
,
param_meta
):
if
param_meta
.
num_bits
==
1
:
if
param_meta
.
num_bits
==
1
:
out
=
DorefaParamsBinarizationSTE
.
apply
(
param_fp
)
out
=
DorefaParamsBinarizationSTE
.
apply
(
param_fp
)
...
@@ -137,6 +116,7 @@ def dorefa_quantize_param(param_fp, param_meta):
...
@@ -137,6 +116,7 @@ def dorefa_quantize_param(param_fp, param_meta):
out
=
2
*
out
-
1
out
=
2
*
out
-
1
return
out
return
out
class
DorefaParamsBinarizationSTE
(
torch
.
autograd
.
Function
):
class
DorefaParamsBinarizationSTE
(
torch
.
autograd
.
Function
):
@staticmethod
@staticmethod
def
forward
(
ctx
,
input
,
inplace
=
False
):
def
forward
(
ctx
,
input
,
inplace
=
False
):
...
@@ -150,6 +130,7 @@ class DorefaParamsBinarizationSTE(torch.autograd.Function):
...
@@ -150,6 +130,7 @@ class DorefaParamsBinarizationSTE(torch.autograd.Function):
def
backward
(
ctx
,
grad_output
):
def
backward
(
ctx
,
grad_output
):
return
grad_output
,
None
return
grad_output
,
None
class
DorefaQuantizer
(
Quantizer
):
class
DorefaQuantizer
(
Quantizer
):
"""
"""
Quantizer using the DoReFa scheme, as defined in:
Quantizer using the DoReFa scheme, as defined in:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment