Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
distiller
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
llvm
distiller
Commits
5b9bec84
Commit
5b9bec84
authored
6 years ago
by
Neta Zmora
Browse files
Options
Downloads
Patches
Plain Diff
Code cleanup: PEP8 and dead code removal for compress_classifier.py
parent
b21f449b
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
examples/classifier_compression/compress_classifier.py
+10
-19
10 additions, 19 deletions
examples/classifier_compression/compress_classifier.py
with
10 additions
and
19 deletions
examples/classifier_compression/compress_classifier.py
+
10
−
19
View file @
5b9bec84
...
@@ -56,7 +56,6 @@ import time
...
@@ -56,7 +56,6 @@ import time
import
os
import
os
import
sys
import
sys
import
random
import
random
import
logging.config
import
traceback
import
traceback
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
functools
import
partial
from
functools
import
partial
...
@@ -68,17 +67,19 @@ import torch.backends.cudnn as cudnn
...
@@ -68,17 +67,19 @@ import torch.backends.cudnn as cudnn
import
torch.optim
import
torch.optim
import
torch.utils.data
import
torch.utils.data
import
torchnet.meter
as
tnt
import
torchnet.meter
as
tnt
try
:
script_dir
=
os
.
path
.
dirname
(
__file__
)
import
distiller
module_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
script_dir
,
'
..
'
,
'
..
'
))
except
ImportError
:
if
module_path
not
in
sys
.
path
:
script_dir
=
os
.
path
.
dirname
(
__file__
)
module_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
script_dir
,
'
..
'
,
'
..
'
))
sys
.
path
.
append
(
module_path
)
sys
.
path
.
append
(
module_path
)
import
distiller
import
distiller
import
apputils
import
apputils
from
distiller.data_loggers
import
TensorBoardLogger
,
PythonLogger
,
ActivationSparsityCollector
from
distiller.data_loggers
import
TensorBoardLogger
,
PythonLogger
,
ActivationSparsityCollector
import
distiller.quantization
as
quantization
import
distiller.quantization
as
quantization
from
models
import
ALL_MODEL_NAMES
,
create_model
from
models
import
ALL_MODEL_NAMES
,
create_model
# Logger handle
msglogger
=
None
msglogger
=
None
...
@@ -209,7 +210,7 @@ def main():
...
@@ -209,7 +210,7 @@ def main():
# Create the model
# Create the model
png_summary
=
args
.
summary
is
not
None
and
args
.
summary
.
startswith
(
'
png
'
)
png_summary
=
args
.
summary
is
not
None
and
args
.
summary
.
startswith
(
'
png
'
)
is_parallel
=
not
png_summary
and
args
.
summary
!=
'
compute
'
# For PNG summary, parallel graphs are illegible
is_parallel
=
not
png_summary
and
args
.
summary
!=
'
compute
'
# For PNG summary, parallel graphs are illegible
model
=
create_model
(
args
.
pretrained
,
args
.
dataset
,
args
.
arch
,
parallel
=
is_parallel
,
device_ids
=
args
.
gpus
)
model
=
create_model
(
args
.
pretrained
,
args
.
dataset
,
args
.
arch
,
parallel
=
is_parallel
,
device_ids
=
args
.
gpus
)
compression_scheduler
=
None
compression_scheduler
=
None
...
@@ -399,8 +400,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
...
@@ -399,8 +400,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
(
'
Top1
'
,
classerr
.
value
(
1
)),
(
'
Top1
'
,
classerr
.
value
(
1
)),
(
'
Top5
'
,
classerr
.
value
(
5
)),
(
'
Top5
'
,
classerr
.
value
(
5
)),
(
'
LR
'
,
lr
),
(
'
LR
'
,
lr
),
(
'
Time
'
,
batch_time
.
mean
)])
(
'
Time
'
,
batch_time
.
mean
)]))
)
distiller
.
log_training_progress
(
stats
,
distiller
.
log_training_progress
(
stats
,
model
.
named_parameters
()
if
log_params_hist
else
None
,
model
.
named_parameters
()
if
log_params_hist
else
None
,
...
@@ -427,13 +427,9 @@ def test(test_loader, model, criterion, loggers, print_freq):
...
@@ -427,13 +427,9 @@ def test(test_loader, model, criterion, loggers, print_freq):
def
_validate
(
data_loader
,
model
,
criterion
,
loggers
,
print_freq
,
epoch
=-
1
):
def
_validate
(
data_loader
,
model
,
criterion
,
loggers
,
print_freq
,
epoch
=-
1
):
"""
Execute the validation/test loop.
"""
"""
Execute the validation/test loop.
"""
losses
=
{
'
objective_loss
'
:
tnt
.
AverageValueMeter
()}
losses
=
{
'
objective_loss
'
:
tnt
.
AverageValueMeter
()}
classerr
=
tnt
.
ClassErrorMeter
(
accuracy
=
True
,
topk
=
(
1
,
5
))
classerr
=
tnt
.
ClassErrorMeter
(
accuracy
=
True
,
topk
=
(
1
,
5
))
batch_time
=
tnt
.
AverageValueMeter
()
batch_time
=
tnt
.
AverageValueMeter
()
# if nclasses<=10:
# # Log the confusion matrix only if the number of classes is small
# confusion = tnt.ConfusionMeter(10)
total_samples
=
len
(
data_loader
.
sampler
)
total_samples
=
len
(
data_loader
.
sampler
)
batch_size
=
data_loader
.
batch_size
batch_size
=
data_loader
.
batch_size
total_steps
=
total_samples
/
batch_size
total_steps
=
total_samples
/
batch_size
...
@@ -456,8 +452,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
...
@@ -456,8 +452,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
# measure accuracy and record loss
# measure accuracy and record loss
losses
[
'
objective_loss
'
].
add
(
loss
.
item
())
losses
[
'
objective_loss
'
].
add
(
loss
.
item
())
classerr
.
add
(
output
.
data
,
target
)
classerr
.
add
(
output
.
data
,
target
)
# if confusion:
# confusion.add(output.data, target)
# measure elapsed time
# measure elapsed time
batch_time
.
add
(
time
.
time
()
-
end
)
batch_time
.
add
(
time
.
time
()
-
end
)
...
@@ -474,9 +468,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
...
@@ -474,9 +468,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
msglogger
.
info
(
'
==> Top1: %.3f Top5: %.3f Loss: %.3f
\n
'
,
msglogger
.
info
(
'
==> Top1: %.3f Top5: %.3f Loss: %.3f
\n
'
,
classerr
.
value
()[
0
],
classerr
.
value
()[
1
],
losses
[
'
objective_loss
'
].
mean
)
classerr
.
value
()[
0
],
classerr
.
value
()[
1
],
losses
[
'
objective_loss
'
].
mean
)
# if confusion:
# msglogger.info('==> Confusion:\n%s', str(confusion.value()))
return
classerr
.
value
(
1
),
classerr
.
value
(
5
),
losses
[
'
objective_loss
'
].
mean
return
classerr
.
value
(
1
),
classerr
.
value
(
5
),
losses
[
'
objective_loss
'
].
mean
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment