diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 8ff0eb3f409c633433a1665041a635d664ae9020..486c84a850dde32134ceffe988ec4921f754281b 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -364,7 +364,7 @@ def main(): return sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities) if args.evaluate: - return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args) + return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args, compression_scheduler) if args.compress: # The main use-case for this sample application is CNN compression. Compression @@ -722,7 +722,7 @@ def earlyexit_validate_stats(args): msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5) return(total_top1, total_top5, losses_exits_stats) -def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args): +def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args, scheduler=None): # This sample application can be invoked to evaluate the accuracy of your model on # the test dataset. # You can optionally quantize the model to 8-bit integer before evaluation. @@ -744,7 +744,7 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector if args.quantize_eval: checkpoint_name = 'quantized' - apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, + apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, scheduler=scheduler, name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name, dir=msglogger.logdir)