Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
distiller
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
llvm
distiller
Commits
d28587cd
Commit
d28587cd
authored
5 years ago
by
Neta Zmora
Browse files
Options
Downloads
Patches
Plain Diff
image_classifier.py: lazily init data-loaders when calling validate/test
parent
54e72012
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
distiller/apputils/image_classifier.py
+14
-7
14 additions, 7 deletions
distiller/apputils/image_classifier.py
with
14 additions
and
7 deletions
distiller/apputils/image_classifier.py
+
14
−
7
View file @
d28587cd
...
@@ -55,7 +55,7 @@ class ClassifierCompressor(object):
...
@@ -55,7 +55,7 @@ class ClassifierCompressor(object):
"""
"""
def
__init__
(
self
,
args
,
script_dir
):
def
__init__
(
self
,
args
,
script_dir
):
self
.
args
=
args
self
.
args
=
args
_
override
_args
(
args
)
_
infer_implicit
_args
(
args
)
self
.
logdir
=
_init_logger
(
args
,
script_dir
)
self
.
logdir
=
_init_logger
(
args
,
script_dir
)
_config_determinism
(
args
)
_config_determinism
(
args
)
_config_compute_device
(
args
)
_config_compute_device
(
args
)
...
@@ -158,10 +158,14 @@ class ClassifierCompressor(object):
...
@@ -158,10 +158,14 @@ class ClassifierCompressor(object):
self
.
_finalize_epoch
(
epoch
,
perf_scores_history
,
top1
,
top5
)
self
.
_finalize_epoch
(
epoch
,
perf_scores_history
,
top1
,
top5
)
def
validate
(
self
,
epoch
=-
1
):
def
validate
(
self
,
epoch
=-
1
):
if
self
.
val_loader
is
None
:
self
.
load_datasets
()
return
validate
(
self
.
val_loader
,
self
.
model
,
self
.
criterion
,
return
validate
(
self
.
val_loader
,
self
.
model
,
self
.
criterion
,
[
self
.
tflogger
,
self
.
pylogger
],
self
.
args
,
epoch
)
[
self
.
tflogger
,
self
.
pylogger
],
self
.
args
,
epoch
)
def
test
(
self
):
def
test
(
self
):
if
self
.
test_loader
is
None
:
self
.
load_datasets
()
return
test
(
self
.
test_loader
,
self
.
model
,
self
.
criterion
,
return
test
(
self
.
test_loader
,
self
.
model
,
self
.
criterion
,
self
.
pylogger
,
self
.
activations_collectors
,
args
=
self
.
args
)
self
.
pylogger
,
self
.
activations_collectors
,
args
=
self
.
args
)
...
@@ -329,7 +333,6 @@ def _config_compute_device(args):
...
@@ -329,7 +333,6 @@ def _config_compute_device(args):
if
args
.
gpus
is
not
None
:
if
args
.
gpus
is
not
None
:
try
:
try
:
args
.
gpus
=
[
int
(
s
)
for
s
in
args
.
gpus
.
split
(
'
,
'
)]
args
.
gpus
=
[
int
(
s
)
for
s
in
args
.
gpus
.
split
(
'
,
'
)]
except
ValueError
:
except
ValueError
:
raise
ValueError
(
'
ERROR: Argument --gpus must be a comma-separated list of integers only
'
)
raise
ValueError
(
'
ERROR: Argument --gpus must be a comma-separated list of integers only
'
)
available_gpus
=
torch
.
cuda
.
device_count
()
available_gpus
=
torch
.
cuda
.
device_count
()
...
@@ -341,10 +344,12 @@ def _config_compute_device(args):
...
@@ -341,10 +344,12 @@ def _config_compute_device(args):
torch
.
cuda
.
set_device
(
args
.
gpus
[
0
])
torch
.
cuda
.
set_device
(
args
.
gpus
[
0
])
def
_
override
_args
(
args
):
def
_
infer_implicit
_args
(
args
):
# Infer the dataset from the model name
# Infer the dataset from the model name
args
.
dataset
=
distiller
.
apputils
.
classification_dataset_str_from_arch
(
args
.
arch
)
if
not
hasattr
(
args
,
'
dataset
'
):
args
.
num_classes
=
distiller
.
apputils
.
classification_num_classes
(
args
.
dataset
)
args
.
dataset
=
distiller
.
apputils
.
classification_dataset_str_from_arch
(
args
.
arch
)
if
not
hasattr
(
args
,
"
num_classes
"
):
args
.
num_classes
=
distiller
.
apputils
.
classification_num_classes
(
args
.
dataset
)
def
_init_learner
(
args
):
def
_init_learner
(
args
):
...
@@ -376,8 +381,8 @@ def _init_learner(args):
...
@@ -376,8 +381,8 @@ def _init_learner(args):
msglogger
.
info
(
'
\n
reset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0
'
)
msglogger
.
info
(
'
\n
reset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0
'
)
if
optimizer
is
None
:
if
optimizer
is
None
:
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
args
.
lr
,
lr
=
args
.
lr
,
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
msglogger
.
debug
(
'
Optimizer Type: %s
'
,
type
(
optimizer
))
msglogger
.
debug
(
'
Optimizer Type: %s
'
,
type
(
optimizer
))
msglogger
.
debug
(
'
Optimizer Args: %s
'
,
optimizer
.
defaults
)
msglogger
.
debug
(
'
Optimizer Args: %s
'
,
optimizer
.
defaults
)
...
@@ -592,6 +597,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
...
@@ -592,6 +597,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
# NOTE: this breaks previous behavior, which returned a history of (top1, top5) values
# NOTE: this breaks previous behavior, which returned a history of (top1, top5) values
return
classerr
.
value
(
1
),
classerr
.
value
(
5
),
losses
[
OVERALL_LOSS_KEY
]
return
classerr
.
value
(
1
),
classerr
.
value
(
5
),
losses
[
OVERALL_LOSS_KEY
]
def
validate
(
val_loader
,
model
,
criterion
,
loggers
,
args
,
epoch
=-
1
):
def
validate
(
val_loader
,
model
,
criterion
,
loggers
,
args
,
epoch
=-
1
):
"""
Model validation
"""
"""
Model validation
"""
if
epoch
>
-
1
:
if
epoch
>
-
1
:
...
@@ -617,6 +623,7 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args):
...
@@ -617,6 +623,7 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args):
def
_is_earlyexit
(
args
):
def
_is_earlyexit
(
args
):
return
hasattr
(
args
,
'
earlyexit_thresholds
'
)
and
args
.
earlyexit_thresholds
return
hasattr
(
args
,
'
earlyexit_thresholds
'
)
and
args
.
earlyexit_thresholds
def
_validate
(
data_loader
,
model
,
criterion
,
loggers
,
args
,
epoch
=-
1
):
def
_validate
(
data_loader
,
model
,
criterion
,
loggers
,
args
,
epoch
=-
1
):
"""
Execute the validation/test loop.
"""
"""
Execute the validation/test loop.
"""
losses
=
{
'
objective_loss
'
:
tnt
.
AverageValueMeter
()}
losses
=
{
'
objective_loss
'
:
tnt
.
AverageValueMeter
()}
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment