diff --git a/README.md b/README.md
index 466a8b80791962ac93bf5e91c5cd9d2f84ab348f..31663d2feb7e3369dccf6b3529d2790076e367de 100755
--- a/README.md
+++ b/README.md
@@ -164,6 +164,9 @@ Beware.
   - Training with [knowledge distillation](https://nervanasystems.github.io/distiller/knowledge_distillation.html), in conjunction with the other available pruning / regularization / quantization methods.
 * **Conditional computation**
   - Sample implementation of Early Exit
+* **Low-rank decomposition**
+  - Sample implementation of [truncated SVD](https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb)
+* Lottery Ticket Hypothesis training
 * Export statistics summaries using Pandas dataframes, which makes it easy to slice, query, display and graph the data.
 * A set of [Jupyter notebooks](https://nervanasystems.github.io/distiller/jupyter/index.html) to plan experiments and analyze compression results. The graphs and visualizations you see on this page originate from the included Jupyter notebooks.
   + Take a look at [this notebook](https://github.com/NervanaSystems/distiller/blob/master/jupyter/alexnet_insights.ipynb), which compares visual aspects of dense and sparse AlexNet models.
diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py
index cad734bae25687ed1bd83bf1bcd395e4adc3e248..0385d0eee9e451ee699de01ac019abb26253d662 100755
--- a/examples/classifier_compression/parser.py
+++ b/examples/classifier_compression/parser.py
@@ -21,6 +21,8 @@ import distiller.models as models
 
 
 def add_cmdline_args(parser):
+    parser.add_argument('--save-untrained-model', action='store_true', default=False,
+                        help='Save the randomly-initialized model before training (useful for the lottery-ticket method)')
     parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
                         help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
     parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
diff --git a/examples/lottery_ticket/README.md b/examples/lottery_ticket/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b9497391896b6570a283f424591b94cddc7bf5a1
--- /dev/null
+++ b/examples/lottery_ticket/README.md
@@ -0,0 +1,52 @@
+## Lottery Ticket Hypothesis
+
+>"The Lottery Ticket Hypothesis: A randomly-initialized, dense neural network contains a subnetwork that is initialized
+such that — when trained in isolation — it can match the test accuracy of the original network after training for at
+most the same number of iterations."
+
+### Finding winning tickets
+> We identify winning tickets by training networks and subsequently pruning their smallest-magnitude weights. The set
+of connections that survives this process is the architecture of a winning ticket. Unique to our work, the winning
+ticket's weights are the values to which these connections were initialized before training. This forms our central
+experiment:
+>1. Randomly initialize a neural network f(x; theta_0) (where theta_0 ~ D_theta).
+>2. Train the network for j iterations, reaching parameters theta_j.
+>3. Prune s% of the parameters, creating a mask m where P_m = (100 - s)%.
+>4. To extract the winning ticket, reset the remaining parameters to their values in theta_0, creating the untrained
+network f(x; m * theta_0).
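+
+The recipe above maps directly onto a few lines of PyTorch. Below is a minimal, framework-independent sketch of steps
+3 and 4; the function name, the default sparsity, and the global magnitude-pruning criterion are illustrative
+assumptions rather than Distiller API (the `lottery.py` script described below does the equivalent work on Distiller
+checkpoints):
+
+```python
+import torch
+
+def winning_ticket(init_state, trained_state, sparsity=0.8):
+    """Sketch of steps 3-4: prune the smallest-magnitude trained weights (step 3),
+    then reset the surviving connections to their theta_0 values (step 4)."""
+    ticket = {}
+    for name, w_trained in trained_state.items():
+        w_init = init_state[name]
+        k = int(sparsity * w_trained.numel())  # number of weights to prune
+        if k == 0:
+            ticket[name] = w_init.clone()
+            continue
+        # The magnitude of the k-th smallest weight acts as the pruning threshold
+        threshold = w_trained.abs().flatten().kthvalue(k).values
+        mask = (w_trained.abs() > threshold).type(w_trained.type())  # mask m
+        ticket[name] = w_init * mask  # f(x; m * theta_0)
+    return ticket
+```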
+
+### Example
+Train a ResNet20-CIFAR10 network from scratch, and save the untrained, randomly-initialized network weights in a
+checkpoint file. To do this, use the `--save-untrained-model` flag: <br>
+```bash
+python3 compress_classifier.py --arch resnet20_cifar ${CIFAR10_PATH} -p=50 --epochs=110 --compress=../ssl/resnet20_cifar_baseline_training.yaml --vs=0 --gpus=0 -j=4 --lr=0.4 --name=resnet20 --save-untrained-model
+```
+
+After training, we have two outputs: the best trained network (`resnet20_best.pth.tar`) and the initial untrained
+network (`resnet20_untrained_checkpoint.pth.tar`).<br>
+In this example, we copy both into the `examples/lottery_ticket` directory for convenience.
+
+```bash
+cp logs/resnet20___2019.08.22-220243/resnet20_best.pth.tar ../lottery_ticket/
+cp logs/resnet20___2019.08.22-220243/resnet20_untrained_checkpoint.pth.tar ../lottery_ticket/
+```
+
+We then prune our best trained ResNet20 network and copy the result into `examples/lottery_ticket` as well.
+```bash
+python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../lottery_ticket/resnet20_best.pth.tar --vs=0 --reset-optimizer --gpus=0
+cp logs/2019.08.22-222752/best.pth.tar ../lottery_ticket/resnet20_pruned.pth.tar
+```
+
+Next, we run `lottery.py` to extract the winning ticket: the untrained weights, multiplied by the pruning mask
+recovered from the pruned checkpoint.
+```bash
+python lottery.py --lt-untrained-ckpt=resnet20_untrained_checkpoint.pth.tar --lt-pruned-ckpt=resnet20_pruned.pth.tar
+```
+
+Finally, we train the winning ticket (`resnet20_untrained_checkpoint.pth.tar_lottery_checkpoint.pth.tar`).
+```bash
+python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --resume-from=../lottery_ticket/resnet20_untrained_checkpoint.pth.tar_lottery_checkpoint.pth.tar --vs=0 --reset-optimizer --gpus=0 --compress=../ssl/resnet20_cifar_baseline_training.yaml
+```
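+
+To confirm that the extracted ticket really is sparse before this final run, you can inspect the checkpoint directly.
+A minimal sketch, assuming the checkpoint stores its weights under the `state_dict` key, as Distiller's
+`save_checkpoint` does:
+
+```python
+import torch
+import distiller
+
+# Report the per-parameter sparsity of the extracted winning ticket.
+ckpt = torch.load('resnet20_untrained_checkpoint.pth.tar_lottery_checkpoint.pth.tar',
+                  map_location='cpu')
+for name, param in ckpt['state_dict'].items():
+    if param.dim() > 1:  # skip biases and batch-norm parameters
+        print(name, distiller.sparsity(param))
+```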
+
+[1] Jonathan Frankle, Michael Carbin<br>
+    The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks<br>
+    arXiv:1803.03635
diff --git a/examples/lottery_ticket/__init__.py b/examples/lottery_ticket/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/examples/lottery_ticket/lottery.py b/examples/lottery_ticket/lottery.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e61545edfe14d2d4bddac121c792be5b99d0661
--- /dev/null
+++ b/examples/lottery_ticket/lottery.py
@@ -0,0 +1,95 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Lottery Ticket Hypothesis: extract a winning ticket from a pair of checkpoints."""
+
+
+import argparse
+
+import torch
+from tabulate import tabulate
+
+import distiller
+import distiller.apputils as apputils
+
+
+def print_sparsities(masks_dict):
+    """Print the sparsity of each mask in `masks_dict`, as a table."""
+    mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items()
+                       if mask is not None]
+    print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid"))
+
+
+def add_args(argparser):
+    """
+    Helper function which defines command-line arguments specific to Lottery Ticket Hypothesis training.
+
+    Arguments:
+        argparser (argparse.ArgumentParser): Existing parser to which to add the arguments
+    """
+    group = argparser.add_argument_group('Lottery Ticket Hypothesis Arguments')
+    group.add_argument('--lt-untrained-ckpt', type=str, action='store',
+                       help='Checkpoint file of the untrained (randomly-initialized) network')
+    group.add_argument('--lt-pruned-ckpt', type=str, action='store',
+                       help='Checkpoint file of the pruned (but not thinned) network')
+    return argparser
+
+
+def extract_lottery_ticket(untrained_ckpt_name, pruned_ckpt_name):
+    """Combine an untrained checkpoint and a pruned checkpoint into a winning-ticket checkpoint."""
+    untrained_model, _, _, _ = apputils.load_checkpoint(model=None, chkpt_file=untrained_ckpt_name,
+                                                        model_device='cpu')
+    pruned_model, pruned_scheduler, optimizer, _ = apputils.load_checkpoint(model=None, chkpt_file=pruned_ckpt_name,
+                                                                            model_device='cpu')
+
+    # Create a dictionary of masks by inferring the masks from the parameter sparsity:
+    # a weight that the pruner zeroed contributes a zero to its parameter's mask.
+    masks_dict = {pname: torch.ne(param, 0).type(param.type())
+                  for pname, param in pruned_model.named_parameters()
+                  if pname in pruned_scheduler.zeros_mask_dict.keys()}
+    # Reset the surviving parameters to their initial (theta_0) values.
+    for pname, mask in masks_dict.items():
+        untrained_model.state_dict()[pname].mul_(mask)
+
+    print_sparsities(masks_dict)
+    pruned_scheduler.init_from_masks_dict(masks_dict)
+
+    # The 'lottery' suffix yields the file name the README resumes from:
+    # <untrained_ckpt_name>_lottery_checkpoint.pth.tar
+    apputils.save_checkpoint(0, pruned_model.arch, untrained_model, optimizer=optimizer,
+                             scheduler=pruned_scheduler,
+                             name='_'.join([untrained_ckpt_name, 'lottery']))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Lottery Ticket Hypothesis')
+    add_args(parser)
+    args = parser.parse_args()
+    extract_lottery_ticket(args.lt_untrained_ckpt, args.lt_pruned_ckpt)
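+
+# Example usage, with the checkpoint names produced by the walk-through in this
+# directory's README.md:
+#
+#   python lottery.py --lt-untrained-ckpt=resnet20_untrained_checkpoint.pth.tar \
+#                     --lt-pruned-ckpt=resnet20_pruned.pth.tar
+#
+# This writes resnet20_untrained_checkpoint.pth.tar_lottery_checkpoint.pth.tar:
+# the winning ticket (untrained weights multiplied by the mask recovered from the
+# pruned checkpoint), ready to be passed to compress_classifier.py --resume-from.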