diff --git a/README.md b/README.md
index 466a8b80791962ac93bf5e91c5cd9d2f84ab348f..31663d2feb7e3369dccf6b3529d2790076e367de 100755
--- a/README.md
+++ b/README.md
@@ -164,6 +164,9 @@ Beware.
   - Training with [knowledge distillation](https://nervanasystems.github.io/distiller/knowledge_distillation.html), in conjunction with the other available pruning / regularization / quantization methods.
 * **Conditional computation**
   - Sample implementation of Early Exit
+* **Low rank decomposition**
+  - Sample implementation of [truncated SVD](https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb)
+* Lottery Ticket Hypothesis training 
 * Export statistics summaries using Pandas dataframes, which makes it easy to slice, query, display and graph the data.
 * A set of [Jupyter notebooks](https://nervanasystems.github.io/distiller/jupyter/index.html) to plan experiments and analyze compression results.  The graphs and visualizations you see on this page originate from the included Jupyter notebooks.  
   + Take a look at [this notebook](https://github.com/NervanaSystems/distiller/blob/master/jupyter/alexnet_insights.ipynb), which compares visual aspects of dense and sparse Alexnet models.
diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py
index cad734bae25687ed1bd83bf1bcd395e4adc3e248..0385d0eee9e451ee699de01ac019abb26253d662 100755
--- a/examples/classifier_compression/parser.py
+++ b/examples/classifier_compression/parser.py
@@ -21,6 +21,8 @@ import distiller.models as models
 
 
 def add_cmdline_args(parser):
+    parser.add_argument('--save-untrained-model',  action='store_true', default=False,
+                        help='Save the randomly-initialized model before training (useful for lottery-ticket method)')
     parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
                         help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
     parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
diff --git a/examples/lottery_ticket/README.md b/examples/lottery_ticket/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b9497391896b6570a283f424591b94cddc7bf5a1
--- /dev/null
+++ b/examples/lottery_ticket/README.md
@@ -0,0 +1,52 @@
+## Lottery Ticket Hypothesis
+
+>The Lottery Ticket Hypothesis: A randomly-initialized, dense neural network contains a subnetwork that is initialized
+such that — when trained in isolation — it can match the test accuracy of the original network after training for at
+most the same number of iterations."
+
+### Finding winning tickets
+> We identify winning tickets by training networks and subsequently pruning
+their smallest-magnitude weights. The set of connections that survives this process is the architecture
+of a winning ticket. Unique to our work, the winning ticket’s weights are the values to which these
+connections were initialized before training. This forms our central experiment:
+>1. Randomly initialize a neural network f(x; theta_0) (where theta_0 ~ D_0).
+>2. Train the network for j iterations, reaching parameters theta_j .
+>3. Prune s% of the parameters, creating a mask m where Pm = (100 - s)%.
+>4. To extract the winning ticket, reset the remaining parameters to their values intheta_0, creating
+the untrained network f(x;m *theta_0).
+
+### Example
+Train a ResNet20-CIFAR10 network from scratch, and save the untrained, randomized initial network weights in a checkpoint file.
+To do this, you use the `--save-untrained-model` flag: <br>
+```bash
+python3 compress_classifier.py --arch resnet20_cifar  ${CIFAR10_PATH} -p=50 --epochs=110 --compress=../ssl/resnet20_cifar_baseline_training.yaml --vs=0 --gpus=0 -j=4 --lr=0.4 --name=resnet20 --save-untrained-model
+``` 
+
+After training the network, we have two outputs: the best trained network (`resnet20_best.pth.tar`) and the initial untrained network (`resnet20_untrained_checkpoint.pth.tar`).<br>
+In this example, we copy them into the `examples/lottery_ticket` directory for convenience.
+
+```bash
+cp logs/resnet20___2019.08.22-220243/resnet20_best.pth.tar ../lottery_ticket/
+cp logs/resnet20___2019.08.22-220243/resnet20_untrained_checkpoint.pth.tar ../lottery_ticket/
+```
+
+We then prune our best trained ResNet20 network and copy the result into `examples/lottery_ticket` as well.
+```bash
+python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml  --resume-from=../lottery_ticket/resnet20_best.pth.tar --vs=0 --reset-optimizer --gpus=0
+cp logs/2019.08.22-222752/best.pth.tar ../lottery_ticket/resnet20_pruned.pth.tar
+```
+
+Next, we run ```lottery.py``` to extract the winning ticket.
+```bash
+python lottery.py --lt-untrained-ckpt=resnet20_untrained_checkpoint.pth.tar --lt-pruned-ckpt=resnet20_pruned.pth.tar
+```
+
+Finally, we train the winning ticket.
+```bash
+python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180  --resume-from=../lottery_ticket/resnet20_untrained_checkpoint.pth.tar_lottery_checkpoint.pth.tar --vs=0 --reset-optimizer --gpus=0 --compress=../ssl/resnet20_cifar_baseline_training.yaml
+```
+
+
+[1] Jonathan Frankle, Michael Carbin<br>
+    The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks<br>
+    arXiv:1803.03635
\ No newline at end of file
diff --git a/examples/lottery_ticket/__init__.py b/examples/lottery_ticket/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/examples/lottery_ticket/lottery.py b/examples/lottery_ticket/lottery.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e61545edfe14d2d4bddac121c792be5b99d0661
--- /dev/null
+++ b/examples/lottery_ticket/lottery.py
@@ -0,0 +1,95 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Lottery Ticket Hypothesis"""
+
+
+import logging
+import distiller.apputils as apputils
+import distiller.models
+import torch
+import argparse
+from tabulate import tabulate
+import distiller
+
+
+def print_sparsities(masks_dict):
+    mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items()
+                       if mask is not None]
+    print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid"))
+
+
+def add_args(argparser):
+    """
+    Helper function which defines command-line arguments specific to Lottery Ticket Hypothesis training.
+
+    Arguments:
+        argparser (argparse.ArgumentParser): Existing parser to which to add the arguments
+    """
+    group = argparser.add_argument_group('AutoML Compression Arguments')
+    group.add_argument('--lt-untrained-ckpt', type=str, action='store',
+                        help='Checkpoint file of the untrained network (randomly initialized)')
+    group.add_argument('--lt-pruned-ckpt', type=str, action='store',
+                        help='Checkpoint file of the pruned (but not thinned) network')
+    return argparser
+
+
+def extract_lottery_ticket(args, untrained_ckpt_name, pruned_ckpt_name):
+    untrained_ckpt = apputils.load_checkpoint(model=None, chkpt_file=untrained_ckpt_name, model_device='cpu')
+    untrained_model, _, optimizer, start_epoch = untrained_ckpt
+
+    pruned_ckpt = apputils.load_checkpoint(model=None, chkpt_file=pruned_ckpt_name, model_device='cpu')
+    pruned_model, pruned_scheduler, optimizer, start_epoch = pruned_ckpt
+
+    # create a dictionary of masks by inferring the masks from the parameter sparsity
+    masks_dict = {pname: (torch.ne(param, 0)).type(param.type())
+                  for pname, param in pruned_model.named_parameters()
+                  if pname in pruned_scheduler.zeros_mask_dict.keys()}
+    for pname, mask in masks_dict.items():
+        untrained_model.state_dict()[pname].mul_(mask)
+
+    sparsities = {pname: distiller.sparsity(mask) for pname, mask in masks_dict.items()}
+    print(sparsities)
+    pruned_scheduler.init_from_masks_dict(masks_dict)
+
+    apputils.save_checkpoint(0, pruned_model.arch, untrained_model, optimizer=optimizer,
+                             scheduler=pruned_scheduler,
+                             name='_'.join([untrained_ckpt_name, 'masked']))
+
+    # pruned_ckpt = torch.load(pruned_ckpt_name, map_location='cpu')
+    #
+    # assert 'extras' in pruned_ckpt and pruned_ckpt['extras']
+    # print("\nContents of Checkpoint['extras']:")
+    # print(get_contents_table(pruned_ckpt['extras']))
+    # masks_dict = pruned_ckpt["extras"]["creation_masks"]
+    # print_sparsities(masks_dict)
+    # compression_scheduler = distiller.CompressionScheduler(model)
+    # compression_scheduler.load_state_dict(masks_dict)
+    #
+    # model = distiller.models.create_model(False, args.dataset, args.arch, device_ids=[-1])
+    # model = apputils.load_lean_checkpoint(model, args.load_model_path, model_device=args.device)
+    #
+    # apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=scheduler,
+    #                      name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name,
+    #                      dir=msglogger.logdir, extras={'quantized_top1': top1})
+
+msglogger = logging.getLogger()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Lottery Ticket Hypothesis')
+    add_args(parser)
+    args = parser.parse_args()
+    extract_lottery_ticket(args, args.lt_untrained_ckpt, args.lt_pruned_ckpt)
\ No newline at end of file