Commit 99124355, authored 5 years ago by Neta Zmora
checkpoint.py: non-functional code refactoring
Rearranged the code for easier reading and maintenance
parent bdafebea
Showing 1 changed file: distiller/apputils/checkpoint.py (61 additions, 54 deletions)
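For orientation, the refactor moves three inline blocks of load_checkpoint() into nested helper functions and replaces each block with a single call. A rough, non-verbatim sketch of the resulting structure (helper bodies elided):

def load_checkpoint(model, chkpt_file, optimizer=None,
                    model_device=None, lean_checkpoint=False, strict=False):
    def _load_compression_scheduler():
        ...  # try/except around compression_scheduler.load_state_dict()

    def _load_and_execute_thinning_recipes():
        ...  # cache the recipes on the model, then execute them

    def _load_optimizer():
        ...  # rebuild the optimizer from its saved class and state dict

    # main body: validate chkpt_file, restore scheduler/thinning/quantizer state,
    # load the model state_dict, then delegate to the helpers above
    ...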
...
@@ -77,7 +77,6 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
         checkpoint['quantizer_metadata'] = model.quantizer_metadata
     checkpoint['extras'] = extras
     torch.save(checkpoint, fullpath)
     if is_best:
         shutil.copyfile(fullpath, fullpath_best)
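The context lines above show the tail of save_checkpoint(): the checkpoint dict is always written to fullpath, and copied to a separate "best" file when is_best is set. A minimal standalone sketch of that pattern (the file names and the extras payload are made up for illustration):

import shutil
import torch

checkpoint = {'epoch': 10, 'arch': 'resnet20_cifar', 'extras': {'best_top1': 91.2}}
fullpath = 'net_checkpoint.pth.tar'
fullpath_best = 'net_best.pth.tar'
is_best = True

torch.save(checkpoint, fullpath)              # always keep the latest checkpoint
if is_best:
    shutil.copyfile(fullpath, fullpath_best)  # snapshot it as the best-so-far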
...
@@ -101,8 +100,8 @@ def get_contents_table(d):
     return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="psql")
 
 
-def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
-                    lean_checkpoint=False, strict=False):
+def load_checkpoint(model, chkpt_file, optimizer=None,
+                    model_device=None, lean_checkpoint=False, strict=False):
     """Load a pytorch training checkpoint.
 
     Args:
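The only signature change is dropping the keyword-only marker (*) and re-wrapping the parameter list; the parameter names, defaults, and order are unchanged, so call sites that pass the optional arguments by keyword behave the same. A hypothetical call site for illustration (the checkpoint path and device are made up):

model, compression_scheduler, optimizer, start_epoch = load_checkpoint(
    model, 'logs/net_best.pth.tar', model_device='cuda',
    lean_checkpoint=False, strict=False)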
...
@@ -114,6 +113,52 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
             This should be set to either 'cpu' or 'cuda'.
     :returns: updated model, compression_scheduler, optimizer, start_epoch
     """
+    def _load_compression_scheduler():
+        normalize_keys = False
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        except KeyError as e:
+            # A very common source of this KeyError is loading a GPU model on the CPU.
+            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
+            normalize_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(checkpoint_epoch))
+        return normalize_keys
+
+    def _load_and_execute_thinning_recipes():
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = [distiller.get_normalized_recipe(recipe)
+                                      for recipe in model.thinning_recipes]
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    def _load_optimizer():
+        """Initialize optimizer with model parameters and load src_state_dict"""
+        try:
+            cls, src_state_dict = checkpoint['optimizer_type'], checkpoint['optimizer_state_dict']
+            # Initialize the dest_optimizer with a dummy learning rate,
+            # this is required to support SGD.__init__()
+            dest_optimizer = cls(model.parameters(), lr=1)
+            dest_optimizer.load_state_dict(src_state_dict)
+            msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(type=type(dest_optimizer)))
+            optimizer_param_groups = dest_optimizer.state_dict()['param_groups']
+            msglogger.info('Optimizer Args: {}'.format(
+                dict((k, v) for k, v in optimizer_param_groups[0].items() if k != 'params')))
+            return dest_optimizer
+        except KeyError:
+            # Older checkpoints do support optimizer loading: They either had an 'optimizer' field
+            # (different name) which was not used during the load, or they didn't even checkpoint
+            # the optimizer.
+            msglogger.warning('Optimizer could not be loaded from checkpoint.')
+            return None
+
     if not os.path.isfile(chkpt_file):
         raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
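The new _load_optimizer() helper keeps the pre-existing trick of instantiating the saved optimizer class with a dummy learning rate (SGD.__init__() requires one) and then letting load_state_dict() overwrite every hyperparameter with the values stored in the checkpoint. A self-contained sketch of that pattern outside Distiller (the key names mirror the diff, but the example itself is hypothetical):

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
ckpt = {'optimizer_type': type(opt), 'optimizer_state_dict': opt.state_dict()}

# Rebuild: the dummy lr=1 only satisfies the constructor...
restored = ckpt['optimizer_type'](model.parameters(), lr=1)
# ...and load_state_dict() restores the real hyperparameters.
restored.load_state_dict(ckpt['optimizer_state_dict'])
assert restored.state_dict()['param_groups'][0]['lr'] == 0.1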
...
@@ -133,30 +178,15 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
     normalize_dataparallel_keys = False
     if 'compression_sched' in checkpoint:
         compression_scheduler = distiller.CompressionScheduler(model)
-        try:
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        except KeyError as e:
-            # A very common source of this KeyError is loading a GPU model on the CPU.
-            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
-            normalize_dataparallel_keys = True
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(checkpoint_epoch))
+        normalize_dataparallel_keys = _load_compression_scheduler()
     else:
         msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
 
     if 'thinning_recipes' in checkpoint:
-        if 'compression_sched' not in checkpoint:
-            msglogger.warning("Found thinning_recipes key, but missing mandatory key compression_sched")
+        if not compression_scheduler:
+            msglogger.warning("Found thinning_recipes key, but missing key compression_scheduler")
             compression_scheduler = distiller.CompressionScheduler(model)
-        msglogger.info("Loaded a thinning recipe from the checkpoint")
-        # Cache the recipes in case we need them later
-        model.thinning_recipes = checkpoint['thinning_recipes']
-        if normalize_dataparallel_keys:
-            model.thinning_recipes = [distiller.get_normalized_recipe(recipe)
-                                      for recipe in model.thinning_recipes]
-        distiller.execute_thinning_recipes_list(model,
-                                                compression_scheduler.zeros_mask_dict,
-                                                model.thinning_recipes)
+        _load_and_execute_thinning_recipes()
 
     if 'quantizer_metadata' in checkpoint:
         msglogger.info('Loaded quantizer metadata from the checkpoint')
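All of the branches above key off optional entries in the checkpoint dict, so a quick way to see which of them a given checkpoint will trigger is to load it and inspect the keys. A hypothetical inspection snippet (the file name is made up):

import torch

ckpt = torch.load('net_checkpoint.pth.tar', map_location='cpu')
for key in ('compression_sched', 'thinning_recipes', 'quantizer_metadata',
            'optimizer_type', 'optimizer_state_dict'):
    print(key, 'present' if key in ckpt else 'absent')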
...
@@ -165,49 +195,26 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
         quantizer.prepare_model(qmd['dummy_input'])
 
     if normalize_dataparallel_keys:
         checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
     anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
     if anomalous_keys:
         # This is pytorch 1.1+
         missing_keys, unexpected_keys = anomalous_keys
         if unexpected_keys:
             msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" % (chkpt_file, len(unexpected_keys)))
         if missing_keys:
             raise ValueError("The loaded checkpoint (%s) is missing %d state keys" % (chkpt_file, len(missing_keys)))
 
     if model_device is not None:
         model.to(model_device)
 
     if lean_checkpoint:
         msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(str(chkpt_file)))
-        return (model, None, None, 0)
+        return model, None, None, 0
 
-    def _load_optimizer(cls, src_state_dict, model):
-        """Initiate optimizer with model parameters and load src_state_dict"""
-        # initiate the dest_optimizer with a dummy learning rate,
-        # this is required to support SGD.__init__()
-        dest_optimizer = cls(model.parameters(), lr=1)
-        dest_optimizer.load_state_dict(src_state_dict)
-        return dest_optimizer
-
-    try:
-        optimizer = _load_optimizer(checkpoint['optimizer_type'],
-                                    checkpoint['optimizer_state_dict'], model)
-    except KeyError:
-        # Older checkpoints do support optimizer loading: They either had an 'optimizer' field
-        # (different name) which was not used during the load, or they didn't even checkpoint
-        # the optimizer.
-        optimizer = None
-
-    if optimizer is not None:
-        msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(type=type(optimizer)))
-        msglogger.info('Optimizer Args: {}'.format(
-            dict((k, v) for k, v in optimizer.state_dict()['param_groups'][0].items() if k != 'params')))
-    else:
-        msglogger.warning('Optimizer could not be loaded from checkpoint.')
+    optimizer = _load_optimizer()
 
     msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
                                                                    e=checkpoint_epoch))
-    return (model, compression_scheduler, optimizer, start_epoch)
+    return model, compression_scheduler, optimizer, start_epoch
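The unchanged context in this last hunk still rewrites the state_dict keys when normalize_dataparallel_keys is set, because a model saved from nn.DataParallel stores its parameters under a "module." prefix that a plain model will not recognize. A toy illustration of that renaming (normalize_key below is only a stand-in for Distiller's normalize_module_name, which handles more cases):

def normalize_key(k):
    # strip the DataParallel wrapper prefix, if present
    return k[len('module.'):] if k.startswith('module.') else k

state_dict = {'module.fc.weight': 'W', 'module.fc.bias': 'b'}
print({normalize_key(k): v for k, v in state_dict.items()})
# -> {'fc.weight': 'W', 'fc.bias': 'b'}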