Commit 78e2e4c7 authored by Neta Zmora

Lottery Ticket Hypothesis

Added support for saving the randomly initialized network before
starting training, and added an implementation showing how to extract
a (winning) lottery ticket from the pristine network and the
pruned network.
parent 6c18a820
@@ -164,6 +164,9 @@ Beware.
- Training with [knowledge distillation](https://nervanasystems.github.io/distiller/knowledge_distillation.html), in conjunction with the other available pruning / regularization / quantization methods.
* **Conditional computation**
- Sample implementation of Early Exit
* **Low rank decomposition**
- Sample implementation of [truncated SVD](https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb)
* Lottery Ticket Hypothesis training
* Export statistics summaries using Pandas dataframes, which makes it easy to slice, query, display and graph the data.
* A set of [Jupyter notebooks](https://nervanasystems.github.io/distiller/jupyter/index.html) to plan experiments and analyze compression results. The graphs and visualizations you see on this page originate from the included Jupyter notebooks.
+ Take a look at [this notebook](https://github.com/NervanaSystems/distiller/blob/master/jupyter/alexnet_insights.ipynb), which compares visual aspects of dense and sparse Alexnet models.
......
@@ -21,6 +21,8 @@ import distiller.models as models
def add_cmdline_args(parser):
    parser.add_argument('--save-untrained-model', action='store_true', default=False,
                        help='Save the randomly-initialized model before training (useful for lottery-ticket method)')
    parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
                        help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
    parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
......
## Lottery Ticket Hypothesis
>The Lottery Ticket Hypothesis: A randomly-initialized, dense neural network contains a subnetwork that is initialized
such that — when trained in isolation — it can match the test accuracy of the original network after training for at
most the same number of iterations.
### Finding winning tickets
> We identify winning tickets by training networks and subsequently pruning
their smallest-magnitude weights. The set of connections that survives this process is the architecture
of a winning ticket. Unique to our work, the winning ticket’s weights are the values to which these
connections were initialized before training. This forms our central experiment:
>1. Randomly initialize a neural network f(x; θ_0) (where θ_0 ~ D_θ).
>2. Train the network for j iterations, reaching parameters θ_j.
>3. Prune s% of the parameters, creating a mask m where P_m = (100 - s)%.
>4. To extract the winning ticket, reset the remaining parameters to their values in θ_0, creating
the untrained network f(x; m⊙θ_0).
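
Steps 3 and 4 reduce to a few tensor operations: infer the mask from the zero-pattern of the pruned weights, then apply it to the initial weights. Below is a minimal sketch of the idea (function and variable names are illustrative; the `lottery.py` script added in this commit, shown further down, does the same via Distiller's checkpoint and scheduler machinery):
```python
import torch

def extract_winning_ticket(untrained_model, pruned_model):
    """Steps 3-4: infer the mask m from the pruned model, then reset the
    surviving connections to their initial values theta_0."""
    untrained_sd = untrained_model.state_dict()
    for pname, param in pruned_model.named_parameters():
        # A connection survived pruning iff its trained weight is non-zero.
        mask = torch.ne(param, 0).type(param.type())
        # m * theta_0: zero-out the pruned connections in the untrained network.
        untrained_sd[pname].mul_(mask)
    return untrained_model  # the winning ticket f(x; m * theta_0)
```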
### Example
Train a ResNet20 network on CIFAR10 from scratch, and save the untrained, randomly-initialized network weights in a checkpoint file.
To do this, use the `--save-untrained-model` flag: <br>
```bash
python3 compress_classifier.py --arch resnet20_cifar ${CIFAR10_PATH} -p=50 --epochs=110 --compress=../ssl/resnet20_cifar_baseline_training.yaml --vs=0 --gpus=0 -j=4 --lr=0.4 --name=resnet20 --save-untrained-model
```
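For reference, here is a minimal sketch of what this flag might trigger inside the training script, using `apputils.save_checkpoint` as it appears elsewhere in this commit (the surrounding variable names are illustrative):
```python
# Illustrative sketch: persist theta_0 before the first training epoch,
# so that a lottery ticket can be extracted from it later.
if args.save_untrained_model:
    apputils.save_checkpoint(0, args.arch, model, optimizer=None,
                             name='_'.join([args.name, 'untrained']) if args.name else 'untrained',
                             dir=msglogger.logdir)
```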
After training the network, we have two outputs: the best trained network (`resnet20_best.pth.tar`) and the initial untrained network (`resnet20_untrained_checkpoint.pth.tar`).<br>
In this example, we copy them into the `examples/lottery_ticket` directory for convenience.
```bash
cp logs/resnet20___2019.08.22-220243/resnet20_best.pth.tar ../lottery_ticket/
cp logs/resnet20___2019.08.22-220243/resnet20_untrained_checkpoint.pth.tar ../lottery_ticket/
```
We then prune our best trained ResNet20 network and copy the result into `examples/lottery_ticket` as well.
```bash
python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../lottery_ticket/resnet20_best.pth.tar --vs=0 --reset-optimizer --gpus=0
cp logs/2019.08.22-222752/best.pth.tar ../lottery_ticket/resnet20_pruned.pth.tar
```
Next, we run `lottery.py` to extract the winning ticket.
```bash
python lottery.py --lt-untrained-ckpt=resnet20_untrained_checkpoint.pth.tar --lt-pruned-ckpt=resnet20_pruned.pth.tar
```
Finally, we train the winning ticket.
```bash
python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --resume-from=../lottery_ticket/resnet20_untrained_checkpoint.pth.tar_lottery_checkpoint.pth.tar --vs=0 --reset-optimizer --gpus=0 --compress=../ssl/resnet20_cifar_baseline_training.yaml
```
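Before retraining, it can be worth sanity-checking the extracted ticket: every pruned connection should be exactly zero in its weights. Here is a short sketch using the same helpers as `lottery.py` (the checkpoint filename is the one produced above):
```python
import distiller
import distiller.apputils as apputils

# Load the extracted winning-ticket checkpoint.
model, _, _, _ = apputils.load_checkpoint(
    model=None,
    chkpt_file='resnet20_untrained_checkpoint.pth.tar_lottery_checkpoint.pth.tar',
    model_device='cpu')

# Report per-parameter sparsity; it should match the pruning mask.
for pname, param in model.named_parameters():
    print("%s: %.1f%% sparse" % (pname, 100 * distiller.sparsity(param)))
```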
[1] Jonathan Frankle, Michael Carbin<br>
The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks<br>
arXiv:1803.03635
\ No newline at end of file
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Lottery Ticket Hypothesis"""
import logging
import distiller.apputils as apputils
import distiller.models
import torch
import argparse
from tabulate import tabulate
import distiller
def print_sparsities(masks_dict):
mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items()
if mask is not None]
print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid"))

def add_args(argparser):
    """Helper function which defines command-line arguments specific to Lottery Ticket Hypothesis training.

    Arguments:
        argparser (argparse.ArgumentParser): Existing parser to which to add the arguments
    """
    group = argparser.add_argument_group('Lottery Ticket Hypothesis Arguments')
    group.add_argument('--lt-untrained-ckpt', type=str, action='store',
                       help='Checkpoint file of the untrained network (randomly initialized)')
    group.add_argument('--lt-pruned-ckpt', type=str, action='store',
                       help='Checkpoint file of the pruned (but not thinned) network')
    return argparser


def extract_lottery_ticket(args, untrained_ckpt_name, pruned_ckpt_name):
    """Extract a (winning) lottery ticket from an untrained network and its pruned counterpart."""
    # Load theta_0: the randomly-initialized network saved before training began.
    untrained_ckpt = apputils.load_checkpoint(model=None, chkpt_file=untrained_ckpt_name, model_device='cpu')
    untrained_model, _, _, _ = untrained_ckpt

    # Load the trained-and-pruned network, whose zero-pattern defines the mask m.
    pruned_ckpt = apputils.load_checkpoint(model=None, chkpt_file=pruned_ckpt_name, model_device='cpu')
    pruned_model, pruned_scheduler, optimizer, start_epoch = pruned_ckpt

    # Create a dictionary of masks by inferring the masks from the parameter sparsity:
    # a connection survived pruning iff its weight is non-zero.
    masks_dict = {pname: (torch.ne(param, 0)).type(param.type())
                  for pname, param in pruned_model.named_parameters()
                  if pname in pruned_scheduler.zeros_mask_dict.keys()}

    # Apply the masks to theta_0, creating the winning ticket f(x; m*theta_0).
    for pname, mask in masks_dict.items():
        untrained_model.state_dict()[pname].mul_(mask)
    print_sparsities(masks_dict)

    pruned_scheduler.init_from_masks_dict(masks_dict)
    apputils.save_checkpoint(0, pruned_model.arch, untrained_model, optimizer=optimizer,
                             scheduler=pruned_scheduler,
                             name='_'.join([untrained_ckpt_name, 'masked']))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Lottery Ticket Hypothesis')
    add_args(parser)
    args = parser.parse_args()
    extract_lottery_ticket(args, args.lt_untrained_ckpt, args.lt_pruned_ckpt)
\ No newline at end of file