Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
distiller
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
llvm
distiller
Commits
ce3528e4
Unverified
Commit
ce3528e4
authored
5 years ago
by
Guy Jacob
Committed by
GitHub
5 years ago
Browse files
Options
Downloads
Patches
Plain Diff
[Quantizer] Fix handling when default bits_activations == None (#345)
parent
e65ec8fc
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
distiller/quantization/quantizer.py
+22
-22
22 additions, 22 deletions
distiller/quantization/quantizer.py
tests/test_quantizer.py
+32
-25
32 additions, 25 deletions
tests/test_quantizer.py
with
54 additions
and
47 deletions
distiller/quantization/quantizer.py
+
22
−
22
View file @
ce3528e4
...
...
@@ -281,28 +281,28 @@ class Quantizer(object):
# We indicate this module wasn't replaced by a wrapper
replace_msg
(
full_name
)
self
.
modules_processed
[
module
]
=
full_name
,
None
continue
# We use a type hint comment to let IDEs know replace_fn is a function
replace_fn
=
self
.
replacement_factory
[
type
(
module
)]
# type: Optional[Callab
le
]
# If the replacement function wasn't specified - continue without replacing this module.
if
replace_fn
is
not
None
:
valid_kwargs
,
invalid_kwargs
=
distiller
.
filter_kwargs
(
self
.
module_overrides_map
[
full_name
],
replace_fn
)
if
invalid_kwargs
:
raise
TypeError
(
"""
Quantizer of type %s doesn
'
t accept
\"
%s
\"
as override arguments for %s. Allowed kwargs: %s
"""
%
(
type
(
self
),
list
(
invalid_kwargs
),
type
(
module
),
list
(
valid_kwargs
)))
new_module
=
replace_fn
(
module
,
full_name
,
self
.
module_qbits_map
,
**
valid_kwargs
)
replace_msg
(
full_name
,
(
module
,
new_module
))
# Add to history of prepared submodules
self
.
modules_processed
[
module
]
=
full_name
,
new_module
setattr
(
container
,
name
,
new_module
)
# If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
if
not
distiller
.
has_children
(
module
)
and
distiller
.
has_children
(
new_module
):
for
sub_module_name
,
sub_module
in
new_module
.
named_modules
():
self
.
_add_qbits_entry
(
full_name
+
'
.
'
+
sub_module_name
,
type
(
sub_module
),
current_qbits
)
self
.
module_qbits_map
[
full_name
]
=
QBits
(
acts
=
current_qbits
.
acts
,
wts
=
None
,
bias
=
None
)
else
:
# We use a type hint comment to let IDEs know replace_fn is a function
replace_fn
=
self
.
replacement_factory
[
type
(
module
)]
# type: Optional[Callable]
# If the replacement function wasn't specified - continue without replacing this modu
le
.
if
replace_fn
is
not
None
:
valid_kwargs
,
invalid_kwargs
=
distiller
.
filter_kwargs
(
self
.
module_overrides_map
[
full_name
],
replace_fn
)
if
invalid_kwargs
:
raise
TypeError
(
"""
Quantizer of type %s doesn
'
t accept
\"
%s
\"
as override arguments for %s. Allowed kwargs: %s
"""
%
(
type
(
self
),
list
(
invalid_kwargs
),
type
(
module
),
list
(
valid_kwargs
)))
new_module
=
replace_fn
(
module
,
full_name
,
self
.
module_qbits_map
,
**
valid_kwargs
)
replace_msg
(
full_name
,
(
module
,
new_module
))
# Add to history of prepared submodules
self
.
modules_processed
[
module
]
=
full_name
,
new_module
setattr
(
container
,
name
,
new_module
)
# If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
if
not
distiller
.
has_children
(
module
)
and
distiller
.
has_children
(
new_module
):
for
sub_module_name
,
sub_module
in
new_module
.
named_modules
():
self
.
_add_qbits_entry
(
full_name
+
'
.
'
+
sub_module_name
,
type
(
sub_module
),
current_qbits
)
self
.
module_qbits_map
[
full_name
]
=
QBits
(
acts
=
current_qbits
.
acts
,
wts
=
None
,
bias
=
None
)
if
distiller
.
has_children
(
module
):
# For container we call recursively
...
...
This diff is collapsed.
Click to expand it.
tests/test_quantizer.py
+
32
−
25
View file @
ce3528e4
...
...
@@ -147,30 +147,34 @@ class DummyQuantizer(Quantizer):
#############################
# Other utils
#############################
expected_type_replacements
=
{
nn
.
Conv2d
:
DummyWrapperLayer
,
nn
.
ReLU
:
DummyQuantLayer
,
nn
.
Linear
:
DummyWrapperLayer
}
def
params_quantizable
(
module
):
return
isinstance
(
module
,
(
nn
.
Conv2d
,
nn
.
Linear
))
def
get_expected_qbits
(
model
,
qbits
,
expected_overrides
):
expected_qbits
=
{}
post_prepare_changes
=
{}
expected_type_replacements
=
{
nn
.
Conv2d
:
DummyWrapperLayer
,
nn
.
ReLU
:
DummyQuantLayer
,
nn
.
Linear
:
DummyWrapperLayer
}
expected_qbits
=
OrderedDict
()
post_prepare_qbbits_changes
=
OrderedDict
()
post_prepare_expected_types
=
OrderedDict
()
prefix
=
'
module.
'
if
isinstance
(
model
,
torch
.
nn
.
DataParallel
)
else
''
for
orig_name
,
orig_module
in
model
.
named_modules
():
orig_module_type
=
type
(
orig_module
)
bits_a
,
bits_w
,
bits_b
=
expected_overrides
.
get
(
orig_name
.
replace
(
prefix
,
''
,
1
),
qbits
)
if
not
params_quantizable
(
orig_module
):
bits_w
=
bits_b
=
None
expected_qbits
[
orig_name
]
=
QBits
(
bits_a
,
bits_w
,
bits_b
)
if
expected_qbits
[
orig_name
]
==
QBits
(
None
,
None
,
None
):
post_prepare_expected_types
[
orig_name
]
=
orig_module_type
else
:
post_prepare_expected_types
[
orig_name
]
=
expected_type_replacements
.
get
(
orig_module_type
,
orig_module_type
)
# We're testing replacement of module with container
if
post_prepare_expected_types
[
orig_name
]
==
DummyWrapperLayer
:
post_prepare_qbbits_changes
[
orig_name
]
=
QBits
(
bits_a
,
None
,
None
)
post_prepare_qbbits_changes
[
orig_name
+
'
.inner
'
]
=
expected_qbits
[
orig_name
]
post_prepare_expected_types
[
orig_name
+
'
.inner
'
]
=
orig_module_type
# We're testing replacement of module with container
if
isinstance
(
orig_module
,
(
nn
.
Conv2d
,
nn
.
Linear
)):
post_prepare_changes
[
orig_name
]
=
QBits
(
bits_a
,
None
,
None
)
post_prepare_changes
[
orig_name
+
'
.inner
'
]
=
expected_qbits
[
orig_name
]
return
expected_qbits
,
post_prepare_changes
return
expected_qbits
,
post_prepare_qbbits_changes
,
post_prepare_expected_types
#############################
...
...
@@ -251,7 +255,10 @@ bias_key = 'bits_bias'
'
sub1.relu2
'
:
QBits
(
8
,
None
,
None
),
'
sub1.pool2
'
:
QBits
(
8
,
None
,
None
)}),
(
QBits
(
8
,
4
,
32
),
OrderedDict
([(
'
conv1
'
,
{
acts_key
:
8
,
wts_key
:
4
,
bias_key
:
None
})]),
{
'
conv1
'
:
QBits
(
8
,
4
,
None
)})
{
'
conv1
'
:
QBits
(
8
,
4
,
None
)}),
(
QBits
(
None
,
8
,
32
),
OrderedDict
([(
'
conv1
'
,
{
acts_key
:
8
,
wts_key
:
8
,
bias_key
:
32
})]),
{
'
conv1
'
:
QBits
(
8
,
8
,
32
)})
],
ids
=
[
'
no_override
'
,
...
...
@@ -260,7 +267,8 @@ bias_key = 'bits_bias'
'
overlap_pattern_override_proper
'
,
# "proper" ==> Specific pattern before broader pattern
'
overlap_pattern_override_wrong
'
,
# "wrong" ==> Broad pattern before specific pattern, so specific pattern
# never actually matched
'
wts_quant_bias_not
'
'
wts_quant_bias_not
'
,
'
dont_quant_acts
'
]
)
def
test_model_prep
(
model
,
optimizer
,
qbits
,
overrides
,
explicit_expected_overrides
,
...
...
@@ -270,7 +278,9 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri
m_orig
=
deepcopy
(
model
)
# Build expected QBits
expected_qbits
,
post_prepare_changes
=
get_expected_qbits
(
model
,
qbits
,
explicit_expected_overrides
)
expected_qbits
,
post_prepare_changes
,
post_prepare_expected_types
=
get_expected_qbits
(
model
,
qbits
,
explicit_expected_overrides
)
# Initialize Quantizer
q
=
DummyQuantizer
(
model
,
optimizer
=
optimizer
,
...
...
@@ -317,15 +327,12 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri
# Check module replacement is as expected
q_module
=
q_named_modules
[
orig_name
]
expected_type
=
expected_type_replacements
.
get
(
type
(
orig_module
))
if
expected_type
is
None
or
expected_qbits
[
orig_name
]
==
QBits
(
None
,
None
,
None
):
assert
type
(
orig_module
)
==
type
(
q_module
)
else
:
assert
type
(
q_module
)
==
expected_type
if
expected_type
==
DummyWrapperLayer
:
assert
expected_qbits
[
orig_name
+
'
.inner
'
]
==
q_module
.
qbits
else
:
assert
expected_qbits
[
orig_name
]
==
q_module
.
qbits
expected_type
=
post_prepare_expected_types
[
orig_name
]
assert
type
(
q_module
)
==
expected_type
if
expected_type
==
DummyWrapperLayer
:
assert
expected_qbits
[
orig_name
+
'
.inner
'
]
==
q_module
.
qbits
elif
expected_type
==
DummyQuantLayer
:
assert
expected_qbits
[
orig_name
]
==
q_module
.
qbits
@pytest.mark.parametrize
(
...
...
@@ -344,7 +351,7 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri
def
test_param_quantization
(
model
,
optimizer
,
qbits
,
overrides
,
explicit_expected_overrides
,
train_with_fp_copy
):
# Build expected QBits
expected_qbits
,
post_prepare_changes
=
get_expected_qbits
(
model
,
qbits
,
explicit_expected_overrides
)
expected_qbits
,
post_prepare_changes
,
_
=
get_expected_qbits
(
model
,
qbits
,
explicit_expected_overrides
)
q
=
DummyQuantizer
(
model
,
optimizer
=
optimizer
,
bits_activations
=
qbits
.
acts
,
bits_weights
=
qbits
.
wts
,
bits_bias
=
qbits
.
bias
,
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment