Skip to content
Snippets Groups Projects
Commit 5b9bec84 authored by Neta Zmora's avatar Neta Zmora
Browse files

Code cleanup: PEP8 and dead code removal for compress_classifier.py

parent b21f449b
No related branches found
No related tags found
No related merge requests found
......@@ -56,7 +56,6 @@ import time
import os
import sys
import random
import logging.config
import traceback
from collections import OrderedDict
from functools import partial
......@@ -68,17 +67,19 @@ import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchnet.meter as tnt
script_dir = os.path.dirname(__file__)
module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
if module_path not in sys.path:
try:
import distiller
except ImportError:
script_dir = os.path.dirname(__file__)
module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
sys.path.append(module_path)
import distiller
import distiller
import apputils
from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector
import distiller.quantization as quantization
from models import ALL_MODEL_NAMES, create_model
# Logger handle
msglogger = None
......@@ -209,7 +210,7 @@ def main():
# Create the model
png_summary = args.summary is not None and args.summary.startswith('png')
is_parallel = not png_summary and args.summary != 'compute' # For PNG summary, parallel graphs are illegible
is_parallel = not png_summary and args.summary != 'compute' # For PNG summary, parallel graphs are illegible
model = create_model(args.pretrained, args.dataset, args.arch, parallel=is_parallel, device_ids=args.gpus)
compression_scheduler = None
......@@ -399,8 +400,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
('Top1', classerr.value(1)),
('Top5', classerr.value(5)),
('LR', lr),
('Time', batch_time.mean)])
)
('Time', batch_time.mean)]))
distiller.log_training_progress(stats,
model.named_parameters() if log_params_hist else None,
......@@ -427,13 +427,9 @@ def test(test_loader, model, criterion, loggers, print_freq):
def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
"""Execute the validation/test loop."""
losses = {'objective_loss' : tnt.AverageValueMeter()}
losses = {'objective_loss': tnt.AverageValueMeter()}
classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
batch_time = tnt.AverageValueMeter()
# if nclasses<=10:
# # Log the confusion matrix only if the number of classes is small
# confusion = tnt.ConfusionMeter(10)
total_samples = len(data_loader.sampler)
batch_size = data_loader.batch_size
total_steps = total_samples / batch_size
......@@ -456,8 +452,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
# measure accuracy and record loss
losses['objective_loss'].add(loss.item())
classerr.add(output.data, target)
# if confusion:
# confusion.add(output.data, target)
# measure elapsed time
batch_time.add(time.time() - end)
......@@ -474,9 +468,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
msglogger.info('==> Top1: %.3f Top5: %.3f Loss: %.3f\n',
classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)
# if confusion:
# msglogger.info('==> Confusion:\n%s', str(confusion.value()))
return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment