Skip to content
Snippets Groups Projects
Commit c25d9ee2 authored by Yi-Syuan Chen's avatar Yi-Syuan Chen Committed by Guy Jacob
Browse files

Save scheduler in quantize_eval checkpoint (#99)

parent 8bcaaa53
No related branches found
No related tags found
No related merge requests found
......@@ -364,7 +364,7 @@ def main():
return sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities)
if args.evaluate:
return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args)
return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args, compression_scheduler)
if args.compress:
# The main use-case for this sample application is CNN compression. Compression
......@@ -722,7 +722,7 @@ def earlyexit_validate_stats(args):
msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5)
return(total_top1, total_top5, losses_exits_stats)
def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args):
def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args, scheduler=None):
# This sample application can be invoked to evaluate the accuracy of your model on
# the test dataset.
# You can optionally quantize the model to 8-bit integer before evaluation.
......@@ -744,7 +744,7 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector
if args.quantize_eval:
checkpoint_name = 'quantized'
apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1,
apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, scheduler=scheduler,
name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name,
dir=msglogger.logdir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment