diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 46c3b1314f746f62c926253e398466f3e031467e..076ddf941eae4b39e6d5b3428cbf1b09745047c1 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -51,7 +51,6 @@ models, or with the provided sample models: """ import math -import argparse import time import os import sys @@ -171,8 +170,7 @@ def main(): # We can optionally resume from a checkpoint if args.resume: - model, compression_scheduler, start_epoch = apputils.load_checkpoint( - model, chkpt_file=args.resume) + model, compression_scheduler, start_epoch = apputils.load_checkpoint(model, chkpt_file=args.resume) model.to(args.device) # Define loss function (criterion) and optimizer @@ -218,6 +216,15 @@ def main(): elif compression_scheduler is None: compression_scheduler = distiller.CompressionScheduler(model) + if args.thinnify: + #zeros_mask_dict = distiller.create_model_masks_dict(model) + assert args.resume is not None, "You must use --resume to provide a checkpoint file to thinnify" + distiller.remove_filters(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None) + apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=compression_scheduler, + name="{}_thinned".format(args.resume.replace(".pth.tar", "")), dir=msglogger.logdir) + print("Note: your model may have collapsed to random inference, so you may want to fine-tune") + return + args.kd_policy = None if args.kd_teacher: teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus) diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py old mode 100644 new mode 100755 index f62174423bca586452cae1cb332ddde35786e57e..435c8a7a1e391df532bdfa51dafde3f9dd92a891 --- a/examples/classifier_compression/parser.py +++ b/examples/classifier_compression/parser.py @@ -51,7 +51,6 @@ def getParser(): parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(), - # choices=["train", "valid", "test"] help='collect activation statistics on phases: train, valid, and/or test' ' (WARNING: this slows down training)') parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False, @@ -94,6 +93,8 @@ def getParser(): help='number of best scores to track and report (default: 1)') parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False, help='Load a model without DataParallel wrapping it') + parser.add_argument('--thinnify', dest='thinnify', action='store_true', default=False, + help='physically remove zero-filters and create a smaller model') str_to_quant_mode_map = { 'sym': distiller.quantization.LinearQuantMode.SYMMETRIC,