llvm / distiller · Commits

Commit a69dd5d6
Authored 6 years ago by Lev Zlotnik, committed 6 years ago by Guy Jacob
Quantizer: Proper handling of modules that point to same object (#239)
Parent: 343e9a82
Showing 2 changed files with 76 additions and 2 deletions:

  distiller/quantization/quantizer.py   +33 −0
  tests/test_quantizer.py               +43 −2
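
The scenario the commit handles is a model in which several attribute names resolve to one and the same submodule object, as in the test's DummyModelWithSharedSubmodule further below. The following standalone sketch (the class names Block and SharedActModel are illustrative, not taken from the repository) shows the shape of the problem: a purely name-keyed replacement pass visits the shared activation once per owning container and would hand out a distinct wrapper for each name, silently breaking the sharing.

    import torch.nn as nn

    class Block(nn.Module):
        """A linear layer followed by a (possibly shared) activation."""
        def __init__(self, in_features, out_features, act=None):
            super(Block, self).__init__()
            self.linear = nn.Linear(in_features, out_features)
            self.act = act if act is not None else nn.ReLU()

        def forward(self, x):
            return self.act(self.linear(x))

    class SharedActModel(nn.Module):
        """block1.act and block2.act reference the same ReLU instance."""
        def __init__(self):
            super(SharedActModel, self).__init__()
            self.block1 = Block(16, 16)
            self.block2 = Block(16, 8, act=self.block1.act)

        def forward(self, x):
            return self.block2(self.block1(x))

    model = SharedActModel()
    assert model.block1.act is model.block2.act  # one ReLU object, two names

    # A name-keyed replacement pass sees the shared ReLU once per parent container
    # and would create two independent replacements for the single original module.
    seen = []
    for container in (model.block1, model.block2):
        for name, module in container.named_children():
            if isinstance(module, nn.ReLU):
                seen.append(module)
    assert seen[0] is seen[1]  # the same object is encountered twice during traversal
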
distiller/quantization/quantizer.py  +33 −0

@@ -173,6 +173,9 @@ class Quantizer(object):
         self.train_with_fp_copy = train_with_fp_copy
         self.params_to_quantize = []
 
+        # A dictionary of replaced modules and their respective names.
+        self.modules_replaced = OrderedDict()
+
     def _add_qbits_entry(self, module_name, module_type, qbits):
         if module_type not in [nn.Conv2d, nn.Linear, nn.Embedding]:
             # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't
@@ -184,6 +187,26 @@ class Quantizer(object):
         self.module_overrides_map[module_name] = entry
 
     def prepare_model(self):
+        """
+        Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
+        and overrides configuration provided to __init__(), and according to the replacement_factory as
+        defined by the Quantizer sub-class being used.
+
+        Note:
+            If multiple sub-modules within the model actually reference the same module, then that module
+            is replaced only once, according to the configuration (bit-width and/or overrides) of the
+            first encountered reference.
+            Toy Example - say a module is constructed using this bit of code:
+
+                shared_relu = nn.ReLU
+                self.relu1 = shared_relu
+                self.relu2 = shared_relu
+
+            When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
+            Let's call it 'new_relu1'. When 'self.relu2' will be encountered, it'll simply be replaced
+            with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
+            will be ignored. A warning message will be shown.
+        """
         self._prepare_model_impl()
 
         msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
@@ -226,6 +249,14 @@ class Quantizer(object):
         # Iterate through model, insert quantization functions as appropriate
         for name, module in container.named_children():
             full_name = prefix + name
+            if module in self.modules_replaced:
+                previous_name, previous_wrapper = self.modules_replaced[module]
+                warnings.warn("Module '{0}' references to same module as '{1}'."
+                              ' Replacing with reference the same wrapper.'.format(full_name, previous_name),
+                              UserWarning)
+                msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, previous_wrapper))
+                setattr(container, name, previous_wrapper)
+                continue
             current_qbits = self.module_qbits_map[full_name]
             if current_qbits.acts is None and current_qbits.wts is None:
                 if self.module_overrides_map[full_name]:
@@ -241,6 +272,8 @@ class Quantizer(object):
                 new_module = self.replacement_factory[type(module)](module, full_name,
                                                                     self.module_qbits_map, **valid_kwargs)
                 msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module))
+                # Add to history of prepared submodules
+                self.modules_replaced[module] = full_name, new_module
                 setattr(container, name, new_module)
 
             # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
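
The key mechanism in the hunks above is the modules_replaced dictionary: it is keyed by the module object itself rather than by its name, so when traversal reaches a second name that resolves to an already-replaced module, the stored wrapper is reused and a warning is emitted instead of a duplicate replacement being created. Below is a minimal standalone sketch of the same pattern; the Wrapper class and replace_relus_once helper are hypothetical stand-ins for the quantizer's replacement_factory machinery, not Distiller API.

    import warnings
    from collections import OrderedDict

    import torch.nn as nn

    class Wrapper(nn.Module):
        """Stand-in for whatever a replacement factory would return."""
        def __init__(self, wrapped):
            super(Wrapper, self).__init__()
            self.wrapped = wrapped

        def forward(self, x):
            return self.wrapped(x)

    def replace_relus_once(container, modules_replaced, prefix=''):
        """Recursively wrap ReLU modules, reusing the wrapper when the same object reappears."""
        for name, module in container.named_children():
            full_name = prefix + name
            if module in modules_replaced:  # keyed by the module object, not by its name
                previous_name, previous_wrapper = modules_replaced[module]
                warnings.warn("'{0}' references the same module as '{1}'; reusing its wrapper."
                              .format(full_name, previous_name), UserWarning)
                setattr(container, name, previous_wrapper)
                continue
            if isinstance(module, nn.ReLU):
                new_module = Wrapper(module)
                modules_replaced[module] = full_name, new_module
                setattr(container, name, new_module)
            else:
                replace_relus_once(module, modules_replaced, full_name + '.')

Applied to the SharedActModel sketch earlier, replace_relus_once(model, OrderedDict()) wraps the shared ReLU exactly once, and block1.act and block2.act still point to one and the same wrapped object afterwards, which is the behaviour the new test below asserts for the real Quantizer.
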
tests/test_quantizer.py  +43 −2

@@ -83,6 +83,33 @@ class DummyModel(nn.Sequential):
                 p.data = torch.zeros_like(p)
 
 
+class DummyDenseWithRelu(nn.Module):
+    def __init__(self, input_size, output_size, relu=None):
+        super(DummyDenseWithRelu, self).__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.relu = relu or nn.ReLU()
+        self.linear = nn.Linear(input_size, output_size)
+
+    def forward(self, x):
+        return self.relu(self.linear(x))
+
+
+class DummyModelWithSharedSubmodule(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(DummyModelWithSharedSubmodule, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+        self.dense1 = DummyDenseWithRelu(input_size, hidden_size)
+        self.dense2 = DummyDenseWithRelu(hidden_size, output_size, self.dense1.relu)
+
+    def forward(self, x):
+        x = self.dense1(x)
+        x = self.dense2(x)
+        return x
+
+
 #############################
 # Dummy Quantizer
 #############################
@@ -111,6 +138,7 @@ class DummyQuantizer(Quantizer):
         self.replacement_factory[nn.Conv2d] = _dummy_wrapper_layer
         self.replacement_factory[nn.ReLU] = _dummy_quant_layer
+        self.replacement_factory[nn.Linear] = _dummy_wrapper_layer
         self.param_quantization_fn = dummy_quantize_params
@@ -118,7 +146,7 @@ class DummyQuantizer(Quantizer):
 # Other utils
 #############################
 
-expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer}
+expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer}
 
 
 def params_quantizable(module):
@@ -136,7 +164,7 @@ def get_expected_qbits(model, qbits, expected_overrides):
         expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b)
 
         # We're testing replacement of module with container
-        if isinstance(orig_module, nn.Conv2d):
+        if isinstance(orig_module, (nn.Conv2d, nn.Linear)):
             post_prepare_changes[orig_name] = QBits(bits_a, None, None)
             post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name]
@@ -394,3 +422,16 @@ def test_overridable_args(model, optimizer, train_with_fp_copy):
     q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
     q.prepare_model()
     assert model_copy.relu1.overridable_prop == 123
+
+
+def test_shared_submodule(optimizer, train_with_fp_copy):
+    with pytest.warns(UserWarning,
+                      match="Module '{0}' references to same module as '{1}'.".format('dense2.relu',
+                                                                                      'dense1.relu')):
+        densenet = DummyModelWithSharedSubmodule(1024, 1024, 1000)
+        quantizer = DummyQuantizer(densenet, bits_weights=8, bits_activations=8, bits_bias=32,
+                                   optimizer=optimizer, train_with_fp_copy=train_with_fp_copy)
+        quantizer.prepare_model()
+
+        assert isinstance(quantizer.model.dense1.relu, DummyQuantLayer)
+        assert quantizer.model.dense1.relu == quantizer.model.dense2.relu
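
To run just the new test case (assuming pytest is installed and the command is issued from the repository root), pytest's keyword filter can be used:

    pytest tests/test_quantizer.py -k test_shared_submodule

The test checks both halves of the fix: that the shared ReLU is replaced with a DummyQuantLayer at all, and that dense1.relu and dense2.relu still resolve to one and the same wrapper after prepare_model(), with the expected UserWarning raised for the second reference.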