Commit 5b9bec84 authored by Neta Zmora

Code cleanup: PEP8 and dead code removal for compress_classifier.py

parent b21f449b
@@ -56,7 +56,6 @@ import time
 import os
 import sys
 import random
-import logging.config
 import traceback
 from collections import OrderedDict
 from functools import partial
@@ -68,17 +67,19 @@ import torch.backends.cudnn as cudnn
 import torch.optim
 import torch.utils.data
 import torchnet.meter as tnt

-script_dir = os.path.dirname(__file__)
-module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
-if module_path not in sys.path:
+try:
+    import distiller
+except ImportError:
+    script_dir = os.path.dirname(__file__)
+    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
     sys.path.append(module_path)
-import distiller
+    import distiller
 import apputils
 from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector
 import distiller.quantization as quantization
 from models import ALL_MODEL_NAMES, create_model

+# Logger handle
 msglogger = None
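The hunk above replaces the unconditional sys.path patching with a fallback import: the installed distiller package is used when it is available, and the source checkout (two levels above the script) is only added to sys.path when that import fails. A minimal, self-contained sketch of the pattern, assuming only what the diff shows:

import os
import sys

try:
    import distiller                      # prefer an installed package
except ImportError:
    # fall back to the repository root, two directories above this script
    script_dir = os.path.dirname(__file__)
    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
    sys.path.append(module_path)
    import distiller                      # retry from the source tree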
@@ -209,7 +210,7 @@ def main():
     # Create the model
     png_summary = args.summary is not None and args.summary.startswith('png')
     is_parallel = not png_summary and args.summary != 'compute'  # For PNG summary, parallel graphs are illegible
     model = create_model(args.pretrained, args.dataset, args.arch, parallel=is_parallel, device_ids=args.gpus)

     compression_scheduler = None
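For readers skimming the hunk above: the two booleans only decide whether the model is created with data-parallel wrapping, since a PNG graph summary of a parallel model is unreadable. A hedged sketch of that flag logic with a hypothetical argparse namespace (create_model itself lives in the repository's models package and is not reproduced here):

from argparse import Namespace

args = Namespace(summary='png')  # illustrative value, e.g. --summary=png on the command line

png_summary = args.summary is not None and args.summary.startswith('png')
is_parallel = not png_summary and args.summary != 'compute'  # parallel graphs render illegibly as PNG
print(png_summary, is_parallel)  # True False
# model = create_model(args.pretrained, args.dataset, args.arch,
#                      parallel=is_parallel, device_ids=args.gpus)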
@@ -399,8 +400,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
                               ('Top1', classerr.value(1)),
                               ('Top5', classerr.value(5)),
                               ('LR', lr),
-                              ('Time', batch_time.mean)])
-        )
+                              ('Time', batch_time.mean)]))

     distiller.log_training_progress(stats,
                                     model.named_parameters() if log_params_hist else None,
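The only change here is PEP8 formatting: the dangling closing parenthesis is folded onto the previous line. The value being closed off is a (name-prefix, OrderedDict) pair of training metrics that train() then hands to distiller.log_training_progress(). A runnable sketch of that structure, with stand-in numbers replacing the torchnet meter values:

from collections import OrderedDict

# stand-ins for losses['objective_loss'].mean, classerr.value(1), classerr.value(5), lr, batch_time.mean
loss_mean, top1, top5, lr, batch_time_mean = 1.234, 71.5, 90.2, 0.01, 0.045

stats = ('Performance/Training/',
         OrderedDict([('Loss', loss_mean),
                      ('Top1', top1),
                      ('Top5', top5),
                      ('LR', lr),
                      ('Time', batch_time_mean)]))
print(stats)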
@@ -427,13 +427,9 @@ def test(test_loader, model, criterion, loggers, print_freq):

 def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
     """Execute the validation/test loop."""
-    losses = {'objective_loss' : tnt.AverageValueMeter()}
+    losses = {'objective_loss': tnt.AverageValueMeter()}
     classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
     batch_time = tnt.AverageValueMeter()
-    # if nclasses<=10:
-    #     # Log the confusion matrix only if the number of classes is small
-    #     confusion = tnt.ConfusionMeter(10)

     total_samples = len(data_loader.sampler)
     batch_size = data_loader.batch_size
     total_steps = total_samples / batch_size
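One note on the unchanged arithmetic at the bottom of this hunk: '/' is float division in Python 3, so total_steps is fractional whenever the last batch is partial. Illustrative numbers only:

total_samples = 50000      # e.g. len(data_loader.sampler) for a 50k-image training set
batch_size = 256
total_steps = total_samples / batch_size
print(total_steps)         # 195.3125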
@@ -456,8 +452,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
         # measure accuracy and record loss
         losses['objective_loss'].add(loss.item())
         classerr.add(output.data, target)
-        # if confusion:
-        #     confusion.add(output.data, target)

         # measure elapsed time
         batch_time.add(time.time() - end)
@@ -474,9 +468,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
     msglogger.info('==> Top1: %.3f Top5: %.3f Loss: %.3f\n',
                    classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)
-    # if confusion:
-    #     msglogger.info('==> Confusion:\n%s', str(confusion.value()))

     return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
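The commented-out lines deleted in the last three hunks were the remains of optional confusion-matrix logging for datasets with few classes. Should anyone want that diagnostic back, a self-contained sketch using torchnet's ConfusionMeter (tensor shapes are illustrative; this is not part of the commit):

import torch
import torchnet.meter as tnt

num_classes = 10
confusion = tnt.ConfusionMeter(num_classes)

output = torch.randn(4, num_classes)          # model scores for a batch of 4 samples
target = torch.randint(0, num_classes, (4,))  # ground-truth class indices
confusion.add(output, target)

print(confusion.value())                      # num_classes x num_classes count matrix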