diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index c6db0b880ba27097f38de87ce7aa5441f26190cb..b92997259259ac5cb647d51934e81fba72e1909d 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -56,7 +56,6 @@ import time import os import sys import random -import logging.config import traceback from collections import OrderedDict from functools import partial @@ -68,17 +67,19 @@ import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data import torchnet.meter as tnt - -script_dir = os.path.dirname(__file__) -module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) -if module_path not in sys.path: +try: + import distiller +except ImportError: + script_dir = os.path.dirname(__file__) + module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) sys.path.append(module_path) -import distiller + import distiller import apputils from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector import distiller.quantization as quantization from models import ALL_MODEL_NAMES, create_model +# Logger handle msglogger = None @@ -209,7 +210,7 @@ def main(): # Create the model png_summary = args.summary is not None and args.summary.startswith('png') - is_parallel = not png_summary and args.summary != 'compute' # For PNG summary, parallel graphs are illegible + is_parallel = not png_summary and args.summary != 'compute' # For PNG summary, parallel graphs are illegible model = create_model(args.pretrained, args.dataset, args.arch, parallel=is_parallel, device_ids=args.gpus) compression_scheduler = None @@ -399,8 +400,7 @@ def train(train_loader, model, criterion, optimizer, epoch, ('Top1', classerr.value(1)), ('Top5', classerr.value(5)), ('LR', lr), - ('Time', batch_time.mean)]) - ) + ('Time', batch_time.mean)])) distiller.log_training_progress(stats, model.named_parameters() if log_params_hist else None, @@ -427,13 +427,9 @@ def test(test_loader, model, criterion, loggers, print_freq): def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1): """Execute the validation/test loop.""" - losses = {'objective_loss' : tnt.AverageValueMeter()} + losses = {'objective_loss': tnt.AverageValueMeter()} classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)) batch_time = tnt.AverageValueMeter() - # if nclasses<=10: - # # Log the confusion matrix only if the number of classes is small - # confusion = tnt.ConfusionMeter(10) - total_samples = len(data_loader.sampler) batch_size = data_loader.batch_size total_steps = total_samples / batch_size @@ -456,8 +452,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1): # measure accuracy and record loss losses['objective_loss'].add(loss.item()) classerr.add(output.data, target) - # if confusion: - # confusion.add(output.data, target) # measure elapsed time batch_time.add(time.time() - end) @@ -474,9 +468,6 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1): msglogger.info('==> Top1: %.3f Top5: %.3f Loss: %.3f\n', classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean) - - # if confusion: - # msglogger.info('==> Confusion:\n%s', str(confusion.value())) return classerr.value(1), classerr.value(5), losses['objective_loss'].mean