Commit e46e196e authored by Neta Zmora

compress_classifier.py: changed the signature of validate() and test()

Due to the various uses of these functions, we need to pass an ever-growing
number of arguments to them, and the API is becoming bloated and unstable.
We now pass the parsed command-line arguments (args) instead, and each
function picks out the options it needs.

Also added the option to log the confusion matrix.
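
In short, the change replaces per-option parameters with the parsed argparse
namespace. A minimal sketch of the before/after calling convention
(illustrative names, not the full Distiller code):

import argparse

# Before: every new option forces a signature change at every call site.
def validate_old(val_loader, model, criterion, loggers, print_freq, epoch=-1):
    pass

# After: one stable signature; options are read off the parsed namespace.
def validate_new(val_loader, model, criterion, loggers, args, epoch=-1):
    if getattr(args, 'display_confusion', False):
        pass  # e.g. accumulate a confusion matrix here

args = argparse.Namespace(print_freq=10, display_confusion=True)
validate_new(None, None, None, [], args)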
parent 4194aa77
@@ -143,6 +143,8 @@ parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='
 parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
                     help='Portion of training dataset to set aside for validation')
 parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
+parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
+parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true', help='Display the confusion matrix')


 def check_pytorch_version():
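
The new --confusion switch is a standard argparse boolean flag:
action='store_true' with default=False means args.display_confusion stays
False unless the flag is given. A self-contained sketch:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--confusion', dest='display_confusion', default=False,
                    action='store_true', help='Display the confusion matrix')

print(parser.parse_args([]).display_confusion)               # False
print(parser.parse_args(['--confusion']).display_confusion)  # True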
@@ -209,6 +211,7 @@ def main():
     # Infer the dataset from the model name
     args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
+    args.num_classes = 10 if args.dataset == 'cifar10' else 1000

     # Create the model
     model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus)
@@ -232,19 +235,24 @@ def main():
     msglogger.info('Optimizer Args: %s', optimizer.defaults)

     if args.ADC:
-        HAVE_GYM_INSTALLED = False
-        if not HAVE_GYM_INSTALLED:
-            import examples.automated_deep_compression.ADC as ADC
+        HAVE_COACH_INSTALLED = True
+        if not HAVE_COACH_INSTALLED:
+            raise ValueError("ADC is currently experimental and uses non-public Coach features")
+        import examples.automated_deep_compression.ADC as ADC
         train_loader, val_loader, test_loader, _ = apputils.load_data(
             args.dataset, os.path.expanduser(args.data), args.batch_size,
             args.workers, args.validation_size, args.deterministic)
+        args.display_confusion = True
         validate_fn = partial(validate, val_loader=test_loader, criterion=criterion,
-                              loggers=[pylogger], print_freq=args.print_freq)
+                              loggers=[pylogger], args=args)
+        if args.ADC_params is not None:
+            ADC.summarize_experiment(args.ADC_params, args.dataset, args.arch, validate_fn)
+            exit()
-        save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, name='adc')
+        save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir)
         ADC.do_adc(model, args.dataset, args.arch, val_loader, validate_fn, save_checkpoint_fn)
         exit()
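
validate_fn and save_checkpoint_fn above use functools.partial to freeze most
arguments up front, so the ADC loop can call them without knowing about
loaders, criteria, or loggers. A toy sketch of that pattern (stand-in
functions, not the Distiller API):

from functools import partial

def validate(model, val_loader, criterion, loggers, args):
    # Stand-in for the real validation loop.
    return 99.9, 99.9, 0.01  # top1, top5, loss

validate_fn = partial(validate, val_loader='test-loader', criterion='xent',
                      loggers=[], args=None)

# Callers now supply only what varies -- the model:
top1, top5, loss = validate_fn(model='resnet20-checkpoint')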
@@ -278,7 +286,7 @@ def main():
         # model. The output is saved to CSV and PNG.
         msglogger.info("Running sensitivity tests")
         test_fnc = partial(test, test_loader=test_loader, criterion=criterion,
-                           loggers=[pylogger], print_freq=args.print_freq)
+                           loggers=[pylogger], args=args)
         which_params = [param_name for param_name, _ in model.named_parameters()]
         sensitivity = distiller.perform_sensitivity_analysis(model,
                                                              net_params=which_params,
@@ -300,7 +308,7 @@ def main():
             quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8)
             quantizer.prepare_model()
             model.cuda()
-        top1, _, _ = test(test_loader, model, criterion, [pylogger], args.print_freq)
+        top1, _, _ = test(test_loader, model, criterion, [pylogger], args=args)
         if args.quantize:
             checkpoint_name = 'quantized'
             apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1,
@@ -315,8 +323,6 @@ def main():
     # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
     model.cuda()

     best_epoch = start_epoch
     for epoch in range(start_epoch, start_epoch + args.epochs):
         # This is the main training loop.
         msglogger.info('\n')
@@ -325,14 +331,14 @@ def main():

         # Train for one epoch
         train(train_loader, model, criterion, optimizer, epoch, compression_scheduler,
-              loggers=[tflogger, pylogger], print_freq=args.print_freq, log_params_hist=args.log_params_histograms)
+              loggers=[tflogger, pylogger], args=args)
         distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
         if args.activation_stats:
             distiller.log_activation_sparsity(epoch, loggers=[tflogger, pylogger],
                                               collector=activations_sparsity)

         # evaluate on validation set
-        top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args.print_freq, epoch)
+        top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args, epoch)
         stats = ('Performance/Validation/',
                  OrderedDict([('Loss', vloss),
                               ('Top1', top1),
@@ -353,11 +359,11 @@ def main():
                                  args.name, msglogger.logdir)

     # Finally run results on the test set
-    test(test_loader, model, criterion, [pylogger], args.print_freq)
+    test(test_loader, model, criterion, [pylogger], args=args)


 def train(train_loader, model, criterion, optimizer, epoch,
-          compression_scheduler, loggers, print_freq, log_params_hist):
+          compression_scheduler, loggers, args):
     """Training loop for one epoch."""
     losses = {'objective_loss': tnt.AverageValueMeter(),
               'regularizer_loss': tnt.AverageValueMeter()}
@@ -413,7 +419,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
         batch_time.add(time.time() - end)
         steps_completed = (train_step+1)

-        if steps_completed % print_freq == 0:
+        if steps_completed % args.print_freq == 0:
             # Log some statistics
             lr = optimizer.param_groups[0]['lr']
             stats = ('Performance/Training/',
@@ -425,36 +431,39 @@ def train(train_loader, model, criterion, optimizer, epoch,
                           ('LR', lr),
                           ('Time', batch_time.mean)]))

+            params = model.named_parameters() if args.log_params_histograms else None
             distiller.log_training_progress(stats,
-                                            model.named_parameters() if log_params_hist else None,
+                                            params,
                                             epoch, steps_completed,
-                                            steps_per_epoch, print_freq,
+                                            steps_per_epoch, args.print_freq,
                                             loggers)
         end = time.time()


-def validate(val_loader, model, criterion, loggers, print_freq, epoch=-1):
+def validate(val_loader, model, criterion, loggers, args, epoch=-1):
     """Model validation"""
     if epoch > -1:
         msglogger.info('--- validate (epoch=%d)-----------', epoch)
     else:
         msglogger.info('--- validate ---------------------')
-    return _validate(val_loader, model, criterion, loggers, print_freq, epoch)
+    return _validate(val_loader, model, criterion, loggers, args, epoch)


-def test(test_loader, model, criterion, loggers, print_freq):
+def test(test_loader, model, criterion, loggers, args):
     """Model Test"""
     msglogger.info('--- test ---------------------')
-    return _validate(test_loader, model, criterion, loggers, print_freq)
+    return _validate(test_loader, model, criterion, loggers, args)


-def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
+def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     """Execute the validation/test loop."""
     losses = {'objective_loss': tnt.AverageValueMeter()}
     classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
     batch_time = tnt.AverageValueMeter()
     total_samples = len(data_loader.sampler)
     batch_size = data_loader.batch_size
+    if args.display_confusion:
+        confusion = tnt.ConfusionMeter(args.num_classes)
     total_steps = total_samples / batch_size
     msglogger.info('%d samples (%d per mini-batch)', total_samples, batch_size)
@@ -475,22 +484,27 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1):
             # measure accuracy and record loss
             losses['objective_loss'].add(loss.item())
             classerr.add(output.data, target)
+            if args.display_confusion:
+                confusion.add(output.data, target)

             # measure elapsed time
             batch_time.add(time.time() - end)
             end = time.time()

             steps_completed = (validation_step+1)
-            if steps_completed % print_freq == 0:
+            if steps_completed % args.print_freq == 0:
                 stats = ('',
                          OrderedDict([('Loss', losses['objective_loss'].mean),
                                       ('Top1', classerr.value(1)),
                                       ('Top5', classerr.value(5))]))
                 distiller.log_training_progress(stats, None, epoch, steps_completed,
-                                                total_steps, print_freq, loggers)
+                                                total_steps, args.print_freq, loggers)

     msglogger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n',
                    classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)
+    if args.display_confusion:
+        msglogger.info('==> Confusion:\n%s', str(confusion.value()))
     return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
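
The confusion logging added above is a thin wrapper around torchnet's
ConfusionMeter, which accumulates a num_classes x num_classes count matrix
from (score, target) pairs. A small standalone sketch, assuming torch and
torchnet are installed:

import torch
import torchnet.meter as tnt

confusion = tnt.ConfusionMeter(3)  # 3 classes

# One mini-batch of raw class scores and integer targets.
output = torch.tensor([[2.0, 0.1, 0.3],   # argmax -> class 0
                       [0.2, 1.5, 0.1]])  # argmax -> class 1
target = torch.tensor([0, 2])
confusion.add(output, target)

# Rows are ground-truth classes, columns are predicted classes.
print(confusion.value())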