diff --git a/.gitignore b/.gitignore
index a1670e9fddf038bfdb2c6a14f0de7843322bba9d..110302734953a65f823713dda57f4f2f53b23a9c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,8 @@ __pycache__/
 .pytest_cache
 .cache
 pytest_collaterals/
+examples/ncf/run/
+examples/ncf/ml-20m*
 
 # GNMT sample
 examples/GNMT/data
diff --git a/distiller/modules/tsvd.py b/distiller/modules/tsvd.py
new file mode 100755
index 0000000000000000000000000000000000000000..aea20bc2e3726bf2dd8104926cd42a873f9bc56f
--- /dev/null
+++ b/distiller/modules/tsvd.py
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Truncated-SVD module.
+
+For an example of how truncated-SVD can be used, see this Jupyter notebook:
+https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb
+
+"""
+
+def truncated_svd(W, l):
+    """Compress the weight matrix W of an inner product (fully connected) layer using truncated SVD.
+
+    For the original implementation (MIT license), see Faster-RCNN:
+    https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py
+    We replaced numpy operations with pytorch operations (so that we can leverage the GPU).
+
+    Arguments:
+      W: N x M weights matrix
+      l: number of singular values to retain
+    Returns:
+      Ul, L: matrices such that W \approx Ul*L
+    """
+
+    U, s, V = torch.svd(W, some=True)
+
+    Ul = U[:, :l]
+    sl = s[:l]
+    V = V.t()
+    Vl = V[:l, :]
+
+    SV = torch.mm(torch.diag(sl), Vl)
+    return Ul, SV
+
+
+class TruncatedSVD(nn.Module):
+    def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):
+        super().__init__()
+        self.replaced_gemm = replaced_gemm
+        print("W = {}".format(gemm_weights.shape))
+        self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))
+        print("U = {}".format(self.U.shape))
+
+        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()
+        self.fc_u.weight.data = self.U
+
+        print("SV = {}".format(self.SV.shape))
+        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()
+        self.fc_sv.weight.data = self.SV#.t()
+
+    def forward(self, x):
+        x = self.fc_sv.forward(x)
+        x = self.fc_u.forward(x)
+        return x
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 56454e7911d03cf60189c7695160401eaef20ca9..96b8dceb70956320c6fc67380ec85ad00cedec21 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -446,8 +446,7 @@ class ActivationRankedFilterPruner(_RankedStructureParameterPruner):
         if fraction_to_prune == 0:
             return
         binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
-                                                 zeros_mask_dict, model, binary_map,
-                                                 self.rounding_fn)
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py
index 06d5d9a31233798b1c4b6112a6d38cc73e0c18e8..e24c97686556d708d111f4b3bc7023bcff5985e9 100644
--- a/distiller/quantization/__init__.py
+++ b/distiller/quantization/__init__.py
@@ -16,7 +16,7 @@
 
 from .quantizer import Quantizer
 from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, PostTrainLinearQuantizer, \
-    LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args,\
+    LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \
     RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode
 from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer
 
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 7bbedc42644e097c18f51f6458df7bddbab63c53..e632cef7259df7dc66357f8dceae952f691ad44e 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -1355,3 +1355,21 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer):
                                                                   per_channel=perch)
             m.register_buffer(ptq.q_attr_name + '_scale', torch.ones_like(scale))
             m.register_buffer(ptq.q_attr_name + '_zero_point', torch.zeros_like(zero_point))
+
+
+class NCFQuantAwareTrainQuantizer(QuantAwareTrainRangeLinearQuantizer):
+    def __init__(self, model, optimizer=None, bits_activations=32, bits_weights=32, bits_bias=32,
+                 overrides=None, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, per_channel_wts=False):
+        super(NCFQuantAwareTrainQuantizer, self).__init__(model, optimizer=optimizer,
+                                                          bits_activations=bits_activations,
+                                                          bits_weights=bits_weights,
+                                                          bits_bias=bits_bias,
+                                                          overrides=overrides,
+                                                          mode=mode, ema_decay=ema_decay,
+                                                          per_channel_wts=per_channel_wts,
+                                                          quantize_inputs=False)
+
+        self.replacement_factory[distiller.modules.EltwiseMult] = self.activation_replace_fn
+        self.replacement_factory[distiller.modules.Concat] = self.activation_replace_fn
+        self.replacement_factory[nn.Linear] = self.activation_replace_fn
+        # self.replacement_factory[nn.Sigmoid] = self.activation_replace_fn
diff --git a/examples/ncf/MLPERF_LICENSE.md b/examples/ncf/MLPERF_LICENSE.md
new file mode 100644
index 0000000000000000000000000000000000000000..89ba5be4833f1b8418c869f63ab8ee1f6e881531
--- /dev/null
+++ b/examples/ncf/MLPERF_LICENSE.md
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2018 The MLPerf Authors
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/examples/ncf/README.md b/examples/ncf/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..eac9b0c5ed7963c996019a7899c9eb0d7c57dfbc
--- /dev/null
+++ b/examples/ncf/README.md
@@ -0,0 +1,172 @@
+# NCF - Neural Collaborative Filtering
+
+The NCF implementation provided here is based on the implementation found in the MLPerf Training GitHub repository.
+This sample is not based on the latest implementation in MLPerf, it is based on an earlier revision which uses the ml-20m dataset. The latest code uses a much larger dataset. We plan to move to the latest version in the near future.  
+You can fine the revision this sample is based on [here](https://github.com/mlperf/training/tree/fe17e837ed12974d15c86d5173fe8f2c188434d5/recommendation/pytorch).
+
+We've made several modifications to the code:
+* Removed all MLPerf specific code including logging
+* In `ncf.py`:
+  * Added calls to Distiller compression APIs
+  * Added progress indication in training and evaluation flows
+* In `neumf.py`:
+  * Added option to split final the FC layer (the `split_final` parameter). See [below](#side-note-splitting-the-final-fc-layer).
+  * Replaced all functional calls with modules so they can be detected by Distiller, as per this [guide](https://nervanasystems.github.io/distiller/prepare_model_quant.html) in the Distiller docs.
+* In `dataset.py`:
+  * Speed up data loading - On first data will is loaded from CSVs and then pickled. On subsequent runs the pickle is loaded. This is much faster than the original implementation, but still very slow.
+  * Added progress indication during data load process
+
+The sample command lines provided [below](#running-the-sample) focus on **post-training quantization**. We did integrate the capability to run quantization-aware training into `ncf.py`. We'll add examples for this at a later time.
+
+## Problem
+
+This task benchmarks recommendation with implicit feedback on the [MovieLens 20 Million (ml-20m) dataset](https://grouplens.org/datasets/movielens/20m/) with a [Neural Collaborative Filtering](http://dl.acm.org/citation.cfm?id=3052569) model.
+The model trains on binary information about whether or not a user interacted with a specific item.
+
+## Setup
+
+### Steps to configure machine
+
+* Install `unzip` and `curl`
+
+  ```bash
+  sudo apt-get install unzip curl
+  ```
+
+* Make sure the latest Distiller requirements are installed
+
+  ```bash
+  # Relative to this sample directory
+  cd <distiller-repo-root>
+  pip install -e .
+  ```
+
+* Download and verify data
+
+  ```bash
+  cd <distiller-repo-root>/examples/ncf
+  # Creates ml-20.zip
+  source ../download_dataset.sh
+  # Confirms the MD5 checksum of ml-20.zip
+  source ../verify_dataset.sh
+  ```
+
+## Running the Sample
+
+### Train a Base FP32 Model
+
+We train a model with the following parameters:
+
+* MLP Side
+  * Embedding size per user / item: 128
+  * FC layer sizes: 256x256 --> 256x128 --> 128x64
+* MF (matrix factorization) Side
+  * Embedding size per user / item: 64
+* Therefore, the final FC layer size is: 128x1
+
+Adam optimizer is used, with an initial learning rate of 0.0005. Batch size is 2048. Convergence is obtained after 7 epochs.
+
+```bash
+python ncf.py ml-20m -l 0.0005 -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --processes 10 -o run/neumf/base_fp32
+...
+Epoch 0 Loss 0.1179 (0.1469): 100%|█████████████████████████████| 48491/48491 [07:04<00:00, 114.23it/s]
+Epoch 0 evaluation
+Epoch 0: HR@10 = 0.5738, NDCG@10 = 0.3367, AvgTrainLoss = 0.1469, train_time = 424.52, val_time = 47.04
+...
+Epoch 6 Loss 0.0914 (0.0943): 100%|█████████████████████████████| 48491/48491 [06:47<00:00, 118.90it/s]
+Epoch 6 evaluation
+Epoch 6: HR@10 = 0.6355, NDCG@10 = 0.3820, AvgTrainLoss = 0.0943, train_time = 407.84, val_time = 62.99
+```
+
+The hit-rate of the base model is 63.55.
+
+### Side-Note: Splitting the Final FC Layer
+
+As mentioned above, we added an option to split the final FC layer of the model (the `split_final` parameter in `NeuMF.__init__`).
+
+The reasoning behind this is that the input to the final FC layer in NCF is a concatenation of the outputs of the MLP and MF "branches". These outputs have very different dynamic ranges.  
+In the model we just trained, the MLP branch output range is [0 .. 203] while the MF branch output range is [-6.3 .. 7.4]. When doing quantized concatenation, we have to accommodate the larger range, which leads to a large quantization error for the data that came from the MF branch. When quantizing to 8-bits, the MF branch will cover only 10 bins out of the 256 bins, which means just over 3-bits.  
+The mitigation we use is to split the final FC layer as follows:
+
+```
+  Before Split:            After Split:
+  -------------            ------------
+  MF_OUT  MLP_OUT          MF_OUT  MLP_OUT
+    \        /               |        |
+     \      /      --->    MF_FC   MLP_FC
+      CONCAT                 \        /
+        |                     \      /
+     FINAL_FC                  \    /
+                                ADD
+```
+After splitting, the two inputs to the add operation have ranges [-283 .. 40] from the MLP side and [-54 .. 47] from the MF side. While the problem isn't completely solved, it's much better than before. Now the MF covers 126 bins, which is almost 7-bits.
+
+Note that in FP32 the 2 modes are functionally identical. The split final option is for evaluation only, and we take care to convert the model trained without splitting into a split model when loading the checkpoint. 
+
+### Collect Quantization Stats for Post-Training Quantization
+
+We generated stats for both the non-split and split case. These are the `quantization_stats_no_split.yaml` and `quantization_stats_split.yaml` files in the example folder.
+
+For reference, the command lines used to generate these are:
+
+```bash
+python ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --qe-calibration 0.1
+python ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --qe-calibration 0.1 --split-final
+```
+Note that `--qe-calibration 0.1` means that we use 10% of the test dataset for the stats collection.
+
+### Post-Training Quantization Experiments
+
+We'll use the following settings for quantization:
+
+* 8-bits for weights and activations: `--qeba 8 --qebw 8`
+* Asymmetric: `--qem asym_u`
+* Per-channel: `--qepc`
+
+Let's see the difference splitting the final FC layer makes in terms of overall accuracy:
+
+```bash
+ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --evaluate --quantize-eval --qeba 8 --qebw 8 --qem asym_u --qepc --qe-stats-file quantization_stats_no_split.yaml
+...
+Initial HR@10 = 0.4954, NDCG@10 = 0.2802, val_time = 521.11
+```
+
+```bash
+ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --evaluate --quantize-eval --qeba 8 --qebw 8 --qem asym_u --qepc --split-final --qe-stats-file quantization_stats_split.yaml
+...
+HR@10 = 0.6278, NDCG@10 = 0.3760, val_time = 601.87
+```
+
+We can see that without splitting, we get ~14% degradation in hit-rate. With splitting we gain almost all of the accuracy back, with about 0.8% degradation.
+
+## Dataset / Environment
+
+### Publication / Attribution
+
+Harper, F. M. & Konstan, J. A. (2015), 'The MovieLens Datasets: History and Context', ACM Trans. Interact. Intell. Syst. 5(4), 19:1--19:19.
+
+### Data preprocessing
+
+1. Unzip
+2. Remove users with less than 20 reviews
+3. Create training and test data separation described below
+
+### Training and test data separation
+
+Positive training examples are all but the last item each user rated.
+Negative training examples are randomly selected from the unrated items for each user.
+
+The last item each user rated is used as a positive example in the test set.
+A fixed set of 999 unrated items are also selected to calculate hit rate at 10 for predicting the test item.
+
+### Training data order
+
+Data is traversed randomly with 4 negative examples selected on average for every positive example.
+
+## Model
+
+### Publication/Attribution
+
+Xiangnan He, Lizi Liao, Hanwang Zhang, Liqiang Nie, Xia Hu and Tat-Seng Chua (2017). [Neural Collaborative Filtering](http://dl.acm.org/citation.cfm?id=3052569). In Proceedings of WWW '17, Perth, Australia, April 03-07, 2017.
+
+The author's original code is available at [hexiangnan/neural_collaborative_filtering](https://github.com/hexiangnan/neural_collaborative_filtering).
diff --git a/examples/ncf/convert.py b/examples/ncf/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c8e5d7ccd33ffaecaa0a63e3fd093386d70c7fb
--- /dev/null
+++ b/examples/ncf/convert.py
@@ -0,0 +1,101 @@
+import os
+from argparse import ArgumentParser
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+from load import implicit_load
+
+
+MIN_RATINGS = 20
+
+
+USER_COLUMN = 'user_id'
+ITEM_COLUMN = 'item_id'
+
+
+TRAIN_RATINGS_FILENAME = 'train-ratings.csv'
+TEST_RATINGS_FILENAME = 'test-ratings.csv'
+TEST_NEG_FILENAME = 'test-negative.csv'
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument('path', type=str,
+                        help='Path to reviews CSV file from MovieLens')
+    parser.add_argument('output', type=str,
+                        help='Output directory for train and test CSV files')
+    parser.add_argument('-n', '--negatives', type=int, default=999,
+                        help='Number of negative samples for each positive'
+                             'test example')
+    parser.add_argument('-s', '--seed', type=int, default=0,
+                        help='Random seed to reproduce same negative samples')
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    np.random.seed(args.seed)
+
+    print("Loading raw data from {}".format(args.path))
+    df = implicit_load(args.path, sort=False)
+    print("Filtering out users with less than {} ratings".format(MIN_RATINGS))
+    grouped = df.groupby(USER_COLUMN)
+    df = grouped.filter(lambda x: len(x) >= MIN_RATINGS)
+
+    print("Mapping original user and item IDs to new sequential IDs")
+    original_users = df[USER_COLUMN].unique()
+    original_items = df[ITEM_COLUMN].unique()
+
+    user_map = {user: index for index, user in enumerate(original_users)}
+    item_map = {item: index for index, item in enumerate(original_items)}
+
+    df[USER_COLUMN] = df[USER_COLUMN].apply(lambda user: user_map[user])
+    df[ITEM_COLUMN] = df[ITEM_COLUMN].apply(lambda item: item_map[item])
+
+    assert df[USER_COLUMN].max() == len(original_users) - 1
+    assert df[ITEM_COLUMN].max() == len(original_items) - 1
+
+    print("Creating list of items for each user")
+    # Need to sort before popping to get last item
+    df.sort_values(by='timestamp', inplace=True)
+    all_ratings = set(zip(df[USER_COLUMN], df[ITEM_COLUMN]))
+    user_to_items = defaultdict(list)
+    for row in tqdm(df.itertuples(), desc='Ratings', total=len(df)):
+        user_to_items[getattr(row, USER_COLUMN)].append(getattr(row, ITEM_COLUMN))  # noqa: E501
+
+    test_ratings = []
+    test_negs = []
+    all_items = set(range(len(original_items)))
+    print("Generating {} negative samples for each user"
+          .format(args.negatives))
+    for user in tqdm(range(len(original_users)), desc='Users', total=len(original_users)):  # noqa: E501
+        test_item = user_to_items[user].pop()
+
+        all_ratings.remove((user, test_item))
+        all_negs = all_items - set(user_to_items[user])
+        all_negs = sorted(list(all_negs))  # determinism
+
+        test_ratings.append((user, test_item))
+        test_negs.append(list(np.random.choice(all_negs, args.negatives)))
+
+    print("Saving train and test CSV files to {}".format(args.output))
+    df_train_ratings = pd.DataFrame(list(all_ratings))
+    df_train_ratings['fake_rating'] = 1
+    df_train_ratings.to_csv(os.path.join(args.output, TRAIN_RATINGS_FILENAME),
+                            index=False, header=False, sep='\t')
+
+    df_test_ratings = pd.DataFrame(test_ratings)
+    df_test_ratings['fake_rating'] = 1
+    df_test_ratings.to_csv(os.path.join(args.output, TEST_RATINGS_FILENAME),
+                           index=False, header=False, sep='\t')
+
+    df_test_negs = pd.DataFrame(test_negs)
+    df_test_negs.to_csv(os.path.join(args.output, TEST_NEG_FILENAME),
+                        index=False, header=False, sep='\t')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/ncf/dataset.py b/examples/ncf/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69fd6ceca410c94fb382aa8889e8e118d2ee985
--- /dev/null
+++ b/examples/ncf/dataset.py
@@ -0,0 +1,132 @@
+import numpy as np
+import scipy
+import scipy.sparse
+import torch
+import torch.utils.data
+import subprocess
+import time
+from tqdm import tqdm
+import os
+import pickle
+import logging
+
+msglogger = logging.getLogger()
+
+
+def wccount(filename):
+    out = subprocess.Popen(['wc', '-l', filename],
+                           stdout=subprocess.PIPE,
+                           stderr=subprocess.STDOUT
+                           ).communicate()[0]
+    return int(out.partition(b' ')[0])
+
+
+class TimingContext(object):
+    def __init__(self, desc):
+        self.desc = desc
+
+    def __enter__(self):
+        msglogger.info(self.desc + ' ... ')
+        self.start = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        end = time.time()
+        msglogger.info('Done in {0:.4f} seconds'.format(end - self.start))
+        return True
+
+
+class CFTrainDataset(torch.utils.data.dataset.Dataset):
+    def __init__(self, train_fname, nb_neg):
+        self._load_train_matrix(train_fname)
+        self.nb_neg = nb_neg
+
+    def _load_train_matrix(self, train_fname):
+        pkl_name = os.path.splitext(train_fname)[0] + '_data.pkl'
+        npz_name = os.path.splitext(train_fname)[0] + '_mat.npz'
+
+        if os.path.isfile(pkl_name) and os.path.isfile(npz_name):
+            msglogger.info('Found saved dataset data structures')
+            with TimingContext('Loading data list pickle'), open(pkl_name, 'rb') as f:
+                self.data = pickle.load(f)
+            with TimingContext('Loading matrix npz'):
+                self.mat = scipy.sparse.dok_matrix(scipy.sparse.load_npz(npz_name))
+            self.nb_users = self.mat.shape[0]
+            self.nb_items = self.mat.shape[1]
+        else:
+            def process_line(line):
+                tmp = line.split('\t')
+                return [int(tmp[0]), int(tmp[1]), float(tmp[2]) > 0]
+
+            with TimingContext('Loading CSV file'), open(train_fname, 'r') as file:
+                data = list(map(process_line, tqdm(file, total=wccount(train_fname))))
+
+            with TimingContext('Calculating min/max'):
+                self.nb_users = max(data, key=lambda x: x[0])[0] + 1
+                self.nb_items = max(data, key=lambda x: x[1])[1] + 1
+
+            with TimingContext('Constructing data list'):
+                self.data = list(filter(lambda x: x[2], data))
+
+            with TimingContext('Saving data list pickle'), open(pkl_name, 'wb') as f:
+                pickle.dump(self.data, f)
+
+            with TimingContext('Building dok matrix'):
+                self.mat = scipy.sparse.dok_matrix(
+                        (self.nb_users, self.nb_items), dtype=np.float32)
+                for user, item, _ in tqdm(data):
+                    self.mat[user, item] = 1.
+
+            with TimingContext('Converting to COO matrix and saving'):
+                scipy.sparse.save_npz(npz_name, self.mat.tocoo(copy=True))
+
+    def __len__(self):
+        return (self.nb_neg + 1) * len(self.data)
+
+    def __getitem__(self, idx):
+        if idx % (self.nb_neg + 1) == 0:
+            idx = idx // (self.nb_neg + 1)
+            return self.data[idx][0], self.data[idx][1], np.ones(1, dtype=np.float32)  # noqa: E501
+        else:
+            idx = idx // (self.nb_neg + 1)
+            u = self.data[idx][0]
+            j = torch.LongTensor(1).random_(0, self.nb_items).item()
+            while (u, j) in self.mat:
+                j = torch.LongTensor(1).random_(0, self.nb_items).item()
+            return u, j, np.zeros(1, dtype=np.float32)
+
+
+def load_test_ratings(fname):
+    pkl_name = os.path.splitext(fname)[0] + '.pkl'
+    if os.path.isfile(pkl_name):
+        with TimingContext('Found test rating pickle file - loading'), open(pkl_name, 'rb') as f:
+            res = pickle.load(f)
+    else:
+        def process_line(line):
+            tmp = map(int, line.split('\t')[0:2])
+            return list(tmp)
+        with TimingContext('Loading test ratings from csv'), open(fname, 'r') as f:
+            ratings = map(process_line, tqdm(f, total=wccount(fname)))
+            res = list(ratings)
+        with TimingContext('Saving test ratings list pickle'), open(pkl_name, 'wb') as f:
+            pickle.dump(res, f)
+
+    return res
+
+
+def load_test_negs(fname):
+    pkl_name = os.path.splitext(fname)[0] + '.pkl'
+    if os.path.isfile(pkl_name):
+        with TimingContext('Found test negatives pickle file - loading'), open(pkl_name, 'rb') as f:
+            res = pickle.load(f)
+    else:
+        def process_line(line):
+            tmp = map(int, line.split('\t'))
+            return list(tmp)
+        with TimingContext('Loading test negatives from csv'), open(fname, 'r') as f:
+            negs = map(process_line, tqdm(f, total=wccount(fname)))
+            res = list(negs)
+        with TimingContext('Saving test negatives list pickle'), open(pkl_name, 'wb') as f:
+            pickle.dump(res, f)
+
+    return res
diff --git a/examples/ncf/download_dataset.sh b/examples/ncf/download_dataset.sh
new file mode 100755
index 0000000000000000000000000000000000000000..8876230f14123358e895595d0126cd9537733908
--- /dev/null
+++ b/examples/ncf/download_dataset.sh
@@ -0,0 +1,16 @@
+function download_20m {
+	echo "Download ml-20m"
+	curl -O http://files.grouplens.org/datasets/movielens/ml-20m.zip
+}
+
+function download_1m {
+	echo "Downloading ml-1m"
+	curl -O http://files.grouplens.org/datasets/movielens/ml-1m.zip
+}
+
+if [[ $1 == "ml-1m" ]]
+then
+	download_1m
+else
+	download_20m
+fi
diff --git a/examples/ncf/load.py b/examples/ncf/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..304f43c2bf9836e212a29fa29fb8820320b460d6
--- /dev/null
+++ b/examples/ncf/load.py
@@ -0,0 +1,68 @@
+from collections import namedtuple
+
+import pandas as pd
+
+
+RatingData = namedtuple('RatingData',
+                        ['items', 'users', 'ratings', 'min_date', 'max_date'])
+
+
+def describe_ratings(ratings):
+    info = RatingData(items=len(ratings['item_id'].unique()),
+                      users=len(ratings['user_id'].unique()),
+                      ratings=len(ratings),
+                      min_date=ratings['timestamp'].min(),
+                      max_date=ratings['timestamp'].max())
+    print("{ratings} ratings on {items} items from {users} users"
+          " from {min_date} to {max_date}"
+          .format(**(info._asdict())))
+    return info
+
+
+def process_movielens(ratings, sort=True):
+    ratings['timestamp'] = pd.to_datetime(ratings['timestamp'], unit='s')
+    if sort:
+        ratings.sort_values(by='timestamp', inplace=True)
+    describe_ratings(ratings)
+    return ratings
+
+
+def load_ml_100k(filename, sort=True):
+    names = ['user_id', 'item_id', 'rating', 'timestamp']
+    ratings = pd.read_csv(filename, sep='\t', names=names)
+    return process_movielens(ratings, sort=sort)
+
+
+def load_ml_1m(filename, sort=True):
+    names = ['user_id', 'item_id', 'rating', 'timestamp']
+    ratings = pd.read_csv(filename, sep='::', names=names, engine='python')
+    return process_movielens(ratings, sort=sort)
+
+
+def load_ml_10m(filename, sort=True):
+    names = ['user_id', 'item_id', 'rating', 'timestamp']
+    ratings = pd.read_csv(filename, sep='::', names=names, engine='python')
+    return process_movielens(ratings, sort=sort)
+
+
+def load_ml_20m(filename, sort=True):
+    ratings = pd.read_csv(filename)
+    ratings['timestamp'] = pd.to_datetime(ratings['timestamp'], unit='s')
+    names = {'userId': 'user_id', 'movieId': 'item_id'}
+    ratings.rename(columns=names, inplace=True)
+    return process_movielens(ratings, sort=sort)
+
+
+DATASETS = [k.replace('load_', '') for k in locals().keys() if "load_" in k]
+
+
+def get_dataset_name(filename):
+    for dataset in DATASETS:
+        if dataset in filename.replace('-', '_').lower():
+            return dataset
+    raise NotImplementedError
+
+
+def implicit_load(filename, sort=True):
+    func = globals()["load_" + get_dataset_name(filename)]
+    return func(filename, sort=sort)
diff --git a/examples/ncf/logging.conf b/examples/ncf/logging.conf
new file mode 100755
index 0000000000000000000000000000000000000000..8db92a75fccb779dc2c02b8e7c668d3cf24363c4
--- /dev/null
+++ b/examples/ncf/logging.conf
@@ -0,0 +1,38 @@
+[formatters]
+keys: simple, time_simple
+
+[handlers]
+keys: console, file
+
+[loggers]
+keys: root, app_cfg
+
+[formatter_simple]
+format: %(message)s
+
+[formatter_time_simple]
+format: %(asctime)s - %(message)s
+
+[handler_console]
+class: StreamHandler
+propagate: 0
+args: []
+formatter: simple
+
+[handler_file]
+class: FileHandler
+mode: 'w'
+args=('%(logfilename)s', 'w')
+formatter: time_simple
+
+[logger_root]
+level: INFO
+propagate: 1
+handlers: console, file
+
+[logger_app_cfg]
+# Use this logger to log the application configuration and execution environment
+level: DEBUG
+qualname: app_cfg
+propagate: 0
+handlers: file
diff --git a/examples/ncf/ncf.py b/examples/ncf/ncf.py
new file mode 100644
index 0000000000000000000000000000000000000000..98376ca825ae800cbfbc8aeae08df176b765d174
--- /dev/null
+++ b/examples/ncf/ncf.py
@@ -0,0 +1,471 @@
+import os
+import heapq
+import math
+import time
+from functools import partial
+from datetime import datetime
+from collections import OrderedDict
+from argparse import ArgumentParser
+import sys
+
+import tqdm
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import multiprocessing as mp
+
+import utils
+from neumf import NeuMF
+from dataset import CFTrainDataset, load_test_ratings, load_test_negs
+from convert import (TEST_NEG_FILENAME, TEST_RATINGS_FILENAME,
+                     TRAIN_RATINGS_FILENAME)
+
+import distiller
+import distiller.quantization as quantization
+import distiller.apputils as apputils
+from distiller.data_loggers import TensorBoardLogger, PythonLogger
+
+msglogger = None
+
+
+def parse_args():
+    parser = ArgumentParser(description="Train a Nerual Collaborative"
+                                        " Filtering model")
+    parser.add_argument('data', type=str,
+                        help='path to test and training data files')
+    parser.add_argument('-e', '--epochs', type=int, default=20,
+                        help='number of epochs for training')
+    parser.add_argument('-b', '--batch-size', type=int, default=256,
+                        help='number of examples for each iteration')
+    parser.add_argument('-f', '--factors', type=int, default=8,
+                        help='number of predictive factors')
+    parser.add_argument('--layers', nargs='+', type=int,
+                        default=[64, 32, 16, 8],
+                        help='size of hidden layers for MLP')
+    parser.add_argument('-n', '--negative-samples', type=int, default=4,
+                        help='number of negative examples per interaction')
+    parser.add_argument('-l', '--learning-rate', type=float, default=0.001,
+                        help='learning rate for optimizer')
+    parser.add_argument('-k', '--topk', type=int, default=10,
+                        help='rank for test examples to be considered a hit')
+    parser.add_argument('--no-cuda', action='store_true',
+                        help='use available GPUs')
+    parser.add_argument('--seed', '-s', type=int,
+                        help='manually set random seed for torch')
+    parser.add_argument('--threshold', '-t', type=float,
+                        help='stop training early at threshold')
+    parser.add_argument('--processes', '-p', type=int, default=1,
+                        help='Number of processes for evaluating model')
+    parser.add_argument('--workers', '-w', type=int, default=8,
+                        help='Number of workers for training DataLoader')
+
+    # Distiller Args
+    # summary_choices = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
+    # parser.add_argument('--summary', type=str, choices=summary_choices,
+    #                     help='print a summary of the model, and exit - options: ' +
+    #                          ' | '.join(summary_choices))
+    parser.add_argument('--load', type=str, metavar='PATH')
+    parser.add_argument('--reset-optimizer', action='store_true')
+    parser.add_argument('--eval', '--evaluate', action='store_true')
+    parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
+                        help='configuration file for pruning the model (default is to use hard-coded schedule)')
+    parser.add_argument('--gpus', metavar='DEV_ID', default=None,
+                        help='Comma-separated list of GPU device IDs to be used '
+                             '(default is to use all available devices)')
+    parser.add_argument('--out-dir', '-o', dest='output_dir', default=os.path.join('run', 'neumf'),
+                        help='Path to dump logs and checkpoints')
+    parser.add_argument('--name', metavar='NAME', default=None, help='Experiment name')
+    parser.add_argument('--log-freq', '--lf', default=100, type=int, metavar='N', help='Logging frequency')
+    parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
+                        help='log the parameter tensors histograms to file '
+                             '(WARNING: this can use significant disk space)')
+    parser.add_argument('--split-final', '--sf', action='store_true')
+    parser.add_argument('--eval-fp16', action='store_true')
+    parser.add_argument('--activation-histograms', '--act-hist',
+                        type=distiller.utils.float_range_argparse_checker(exc_min=True),
+                        metavar='PORTION_OF_TEST_SET',
+                        help='Run the model in evaluation mode on the specified portion of the test dataset and '
+                             'generate activation histograms. NOTE: This slows down evaluation significantly')
+    quantization.add_post_train_quant_args(parser)
+
+    return parser.parse_args()
+
+
+def predict(model, users, items, batch_size=1024, use_cuda=True):
+    with torch.no_grad():
+        batches = [(users[i:i + batch_size], items[i:i + batch_size])
+                   for i in range(0, len(users), batch_size)]
+        preds = []
+        for user, item in batches:
+            def proc(x):
+                x = np.array(x)
+                x = torch.from_numpy(x)
+                if use_cuda:
+                    x = x.cuda(async=True)
+                return torch.autograd.Variable(x)
+            outp = model(proc(user), proc(item), torch.tensor([True], dtype=torch.bool))
+            outp = outp.data.cpu().numpy()
+            preds += list(outp.flatten())
+        return preds
+
+
+def _calculate_hit(ranked, test_item):
+    return int(test_item in ranked)
+
+
+def _calculate_ndcg(ranked, test_item):
+    for i, item in enumerate(ranked):
+        if item == test_item:
+            return math.log(2) / math.log(i + 2)
+    return 0.
+
+
+def eval_one(rating, items, model, K, use_cuda=True):
+    user = rating[0]
+    test_item = rating[1]
+    items.append(test_item)
+    # items.insert(0, test_item)
+    users = [user] * len(items)
+    predictions = predict(model, users, items, use_cuda=use_cuda)
+
+    map_item_score = {item: pred for item, pred in zip(items, predictions)}
+    ranked = heapq.nlargest(K, map_item_score, key=map_item_score.get)
+
+    hit = _calculate_hit(ranked, test_item)
+    ndcg = _calculate_ndcg(ranked, test_item)
+    # return user, hit, ndcg
+    return hit, ndcg
+
+
+def val_epoch(model, ratings, negs, K, use_cuda=True, output=None, epoch=None,
+              processes=1, num_users=-1):
+    if epoch is None:
+        msglogger.info("Initial evaluation")
+    else:
+        msglogger.info("Epoch {} evaluation".format(epoch))
+    start = datetime.now()
+    model.eval()
+
+    if num_users > 0:
+        ratings = ratings[:num_users]
+        negs = negs[:num_users]
+
+    if processes > 1:
+        context = mp.get_context('spawn')
+        _eval_one = partial(eval_one, model=model, K=K, use_cuda=use_cuda)
+        with context.Pool(processes=processes) as workers:
+            hits_and_ndcg = workers.starmap(_eval_one, zip(ratings, negs))
+        hits, ndcgs = zip(*hits_and_ndcg)
+    else:
+        hits, ndcgs = [], []
+        with tqdm.tqdm(zip(ratings, negs), total=len(ratings)) as t:
+            for rating, items in t:
+                hit, ndcg = eval_one(rating, items, model, K, use_cuda=use_cuda)
+                hits.append(hit)
+                ndcgs.append(ndcg)
+                steps_completed = len(hits) + 1
+                if steps_completed % 100 == 0:
+                    t.set_description('HR@10 = {0:.4f}, NDCG = {1:.4f}'.format(np.mean(hits), np.mean(ndcgs)))
+
+    hits = np.array(hits, dtype=np.float32)
+    ndcgs = np.array(ndcgs, dtype=np.float32)
+
+    end = datetime.now()
+    if output is not None:
+        result = OrderedDict()
+        result['timestamp'] = datetime.now()
+        result['duration'] = end - start
+        result['epoch'] = epoch
+        result['K'] = K
+        result['hit_rate'] = np.mean(hits)
+        result['NDCG'] = np.mean(ndcgs)
+        utils.save_result(result, output)
+
+    return hits, ndcgs
+
+
+def main():
+    global msglogger
+
+    script_dir = os.path.dirname(__file__)
+    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
+
+    args = parse_args()
+
+    # Distiller loggers
+    msglogger = apputils.config_pylogger('logging.conf', args.name, output_dir=args.output_dir)
+    tflogger = TensorBoardLogger(msglogger.logdir)
+    # tflogger.log_gradients = True
+    # pylogger = PythonLogger(msglogger)
+
+    if args.seed is not None:
+        msglogger.info("Using seed = {}".format(args.seed))
+        torch.manual_seed(args.seed)
+        np.random.seed(seed=args.seed)
+
+    args.qe_mode = str(args.qe_mode).split('.')[1]
+    args.qe_clip_acts = str(args.qe_clip_acts).split('.')[1]
+
+    apputils.log_execution_env_state(sys.argv, gitroot=module_path)
+
+    if args.gpus is not None:
+        try:
+            args.gpus = [int(s) for s in args.gpus.split(',')]
+        except ValueError:
+            msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
+            exit(1)
+        if len(args.gpus) > 1:
+            msglogger.error('ERROR: Only single GPU supported for NCF')
+            exit(1)
+        available_gpus = torch.cuda.device_count()
+        for dev_id in args.gpus:
+            if dev_id >= available_gpus:
+                msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
+                                .format(dev_id, available_gpus))
+                exit(1)
+        # Set default device in case the first one on the list != 0
+        torch.cuda.set_device(args.gpus[0])
+
+    # Save configuration to file
+    config = {k: v for k, v in args.__dict__.items()}
+    config['timestamp'] = "{:.0f}".format(datetime.utcnow().timestamp())
+    config['local_timestamp'] = str(datetime.now())
+    run_dir = msglogger.logdir
+    msglogger.info("Saving config and results to {}".format(run_dir))
+    if not os.path.exists(run_dir) and run_dir != '':
+        os.makedirs(run_dir)
+    utils.save_config(config, run_dir)
+
+    # Check that GPUs are actually available
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+
+    t1 = time.time()
+    # Load Data
+    training = not (args.eval or args.qe_calibration or args.activation_histograms)
+    msglogger.info('Loading data')
+    if training:
+        train_dataset = CFTrainDataset(
+            os.path.join(args.data, TRAIN_RATINGS_FILENAME), args.negative_samples)
+        train_dataloader = torch.utils.data.DataLoader(
+            dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
+            num_workers=args.workers, pin_memory=True)
+        nb_users, nb_items = train_dataset.nb_users, train_dataset.nb_items
+    else:
+        train_dataset = None
+        train_dataloader = None
+        nb_users, nb_items = (138493, 26744)
+
+    test_ratings = load_test_ratings(os.path.join(args.data, TEST_RATINGS_FILENAME))  # noqa: E501
+    test_negs = load_test_negs(os.path.join(args.data, TEST_NEG_FILENAME))
+
+    msglogger.info('Load data done [%.1f s]. #user=%d, #item=%d, #train=%s, #test=%d'
+              % (time.time()-t1, nb_users, nb_items, str(train_dataset.mat.nnz) if training else 'N/A',
+                 len(test_ratings)))
+
+    # Create model
+    model = NeuMF(nb_users, nb_items,
+                  mf_dim=args.factors, mf_reg=0.,
+                  mlp_layer_sizes=args.layers,
+                  mlp_layer_regs=[0. for i in args.layers],
+                  split_final=args.split_final)
+    if use_cuda:
+        model = model.cuda()
+    msglogger.info(model)
+    msglogger.info("{} parameters".format(utils.count_parameters(model)))
+
+    # Save model text description
+    with open(os.path.join(run_dir, 'model.txt'), 'w') as file:
+        file.write(str(model))
+
+    compression_scheduler = None
+    start_epoch = 0
+    optimizer = None
+    if args.load:
+        if training:
+            model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(model, args.load)
+            if args.reset_optimizer:
+                start_epoch = 0
+                optimizer = None
+        else:
+            model = apputils.load_lean_checkpoint(model, args.load)
+
+    # Add loss to graph
+    criterion = nn.BCEWithLogitsLoss()
+
+    if use_cuda:
+        criterion = criterion.cuda()
+
+    if training and optimizer is None:
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+        msglogger.info('Optimizer Type: %s', type(optimizer))
+        msglogger.info('Optimizer Args: %s', optimizer.defaults)
+
+    if args.compress:
+        compression_scheduler = distiller.file_config(model, optimizer, args.compress)
+        model.cuda()
+
+    # Create files for tracking training
+    valid_results_file = os.path.join(run_dir, 'valid_results.csv')
+
+    if args.qe_calibration or args.activation_histograms:
+        calib = {'portion': args.qe_calibration,
+                 'desc_str': 'quantization calibration stats',
+                 'collect_func': partial(distiller.data_loggers.collect_quant_stats, inplace_runtime_check=True,
+                                         disable_inplace_attrs=True)}
+        hists = {'portion': args.activation_histograms,
+                 'desc_str': 'activation histograms',
+                 'collect_func': partial(distiller.data_loggers.collect_histograms, activation_stats=None, nbins=2048,
+                                         save_hist_imgs=True)}
+        d = calib if args.qe_calibration else hists
+
+        distiller.utils.assign_layer_fq_names(model)
+        num_users = int(np.floor(len(test_ratings) * d['portion']))
+        msglogger.info(
+            "Generating {} based on {:.1%} of the test-set ({} users)".format(d['desc_str'], d['portion'], num_users))
+
+        test_fn = partial(val_epoch, ratings=test_ratings, negs=test_negs, K=args.topk, use_cuda=use_cuda,
+                          processes=args.processes, num_users=num_users)
+        d['collect_func'](model=model, test_fn=test_fn, save_dir=run_dir, classes=None)
+
+        return 0
+
+    if args.eval:
+        if args.quantize_eval and args.qe_calibration is None:
+            model.cpu()
+            quantizer = quantization.PostTrainLinearQuantizer.from_args(model, args)
+            dummy_input = (torch.tensor([1]), torch.tensor([1]), torch.tensor([True], dtype=torch.bool))
+            quantizer.prepare_model(dummy_input)
+            model.cuda()
+
+        distiller.utils.assign_layer_fq_names(model)
+
+        if args.eval_fp16:
+            model = model.half()
+
+        # Calculate initial Hit Ratio and NDCG
+        begin = time.time()
+        hits, ndcgs = val_epoch(model, test_ratings, test_negs, args.topk,
+                                use_cuda=use_cuda, processes=args.processes)
+        val_time = time.time() - begin
+        hit_rate = np.mean(hits)
+        msglogger.info('Initial HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}, val_time = {val_time:.2f}'
+                       .format(K=args.topk, hit_rate=hit_rate, ndcg=np.mean(ndcgs), val_time=val_time))
+        hit_rate = 0
+
+        if args.quantize_eval:
+            checkpoint_name = 'quantized'
+            apputils.save_checkpoint(0, 'NCF', model, optimizer=None, extras={'quantized_hr@10': hit_rate},
+                                     name='_'.join([args.name, 'quantized']) if args.name else checkpoint_name,
+                                     dir=msglogger.logdir)
+        return 0
+
+    total_samples = len(train_dataloader.sampler)
+    steps_per_epoch = math.ceil(total_samples / args.batch_size)
+    best_hit_rate = 0
+    best_epoch = 0
+    for epoch in range(start_epoch, args.epochs):
+        msglogger.info('')
+        model.train()
+        losses = utils.AverageMeter()
+
+        begin = time.time()
+
+        if compression_scheduler:
+            compression_scheduler.on_epoch_begin(epoch, optimizer)
+
+        loader = tqdm.tqdm(train_dataloader)
+        for batch_index, (user, item, label) in enumerate(loader):
+            user = torch.autograd.Variable(user, requires_grad=False)
+            item = torch.autograd.Variable(item, requires_grad=False)
+            label = torch.autograd.Variable(label, requires_grad=False)
+            if use_cuda:
+                user = user.cuda(async=True)
+                item = item.cuda(async=True)
+                label = label.cuda(async=True)
+
+            if compression_scheduler:
+                compression_scheduler.on_minibatch_begin(epoch, batch_index, steps_per_epoch, optimizer)
+
+            outputs = model(user, item, torch.tensor([False], dtype=torch.bool))
+            loss = criterion(outputs, label)
+
+            if compression_scheduler:
+                compression_scheduler.before_backward_pass(epoch, batch_index, steps_per_epoch, loss, optimizer,
+                                                           return_loss_components=False)
+
+            losses.update(loss.data.item(), user.size(0))
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if compression_scheduler:
+                compression_scheduler.on_minibatch_end(epoch, batch_index, steps_per_epoch, optimizer)
+
+            # Save stats to file
+            description = ('Epoch {} Loss {loss.val:.4f} ({loss.avg:.4f})'
+                           .format(epoch, loss=losses))
+            loader.set_description(description)
+
+            steps_completed = batch_index + 1
+            if steps_completed % args.log_freq == 0:
+                stats_dict = OrderedDict()
+                stats_dict['Loss'] = losses.avg
+                stats = ('Performance/Training/', stats_dict)
+                params = model.named_parameters() if args.log_params_histograms else None
+                distiller.log_training_progress(stats, params, epoch, steps_completed, steps_per_epoch, args.log_freq,
+                                                [tflogger])
+
+                tflogger.log_model_buffers(model, ['tracked_min', 'tracked_max'], 'Quant/Train/Acts/TrackedMinMax',
+                                           epoch, steps_completed, steps_per_epoch, args.log_freq)
+
+        train_time = time.time() - begin
+        begin = time.time()
+        hits, ndcgs = val_epoch(model, test_ratings, test_negs, args.topk,
+                                use_cuda=use_cuda, output=valid_results_file,
+                                epoch=epoch, processes=args.processes)
+        val_time = time.time() - begin
+
+        if compression_scheduler:
+            compression_scheduler.on_epoch_end(epoch, optimizer)
+
+        hit_rate = np.mean(hits)
+        mean_ndcgs = np.mean(ndcgs)
+
+        stats_dict = OrderedDict()
+        stats_dict['HR@{0}'.format(args.topk)] = hit_rate
+        stats_dict['NDCG@{0}'.format(args.topk)] = mean_ndcgs
+        stats = ('Performance/Validation/', stats_dict)
+        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1,
+                                        loggers=[tflogger])
+
+        msglogger.info('Epoch {epoch}: HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}, AvgTrainLoss = {loss.avg:.4f}, '
+                       'train_time = {train_time:.2f}, val_time = {val_time:.2f}'.format(
+                            epoch=epoch, K=args.topk, hit_rate=hit_rate, ndcg=mean_ndcgs,
+                            loss=losses, train_time=train_time, val_time=val_time))
+
+        is_best = False
+        if hit_rate > best_hit_rate:
+            best_hit_rate = hit_rate
+            is_best = True
+            best_epoch = epoch
+        extras = {'current_hr@10': hit_rate,
+                  'best_hr@10': best_hit_rate,
+                  'best_epoch': best_epoch}
+        apputils.save_checkpoint(epoch, 'NCF', model, optimizer, compression_scheduler, extras, is_best, dir=run_dir)
+
+        if args.threshold is not None:
+            if np.mean(hits) >= args.threshold:
+                msglogger.info("Hit threshold of {}".format(args.threshold))
+                break
+
+
+if __name__ == '__main__':
+    try:
+        main()
+    except KeyboardInterrupt:
+        print("\n-- KeyboardInterrupt --")
+    finally:
+        if msglogger is not None:
+            msglogger.info('')
+            msglogger.info('Log file for this run: ' + os.path.realpath(msglogger.log_filename))
diff --git a/examples/ncf/neumf.py b/examples/ncf/neumf.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6179f00a23ebb7c3354228a5d554191d8e1f638
--- /dev/null
+++ b/examples/ncf/neumf.py
@@ -0,0 +1,120 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+import distiller.modules
+
+import logging
+import os
+msglogger = logging.getLogger()
+
+
+class NeuMF(nn.Module):
+    def __init__(self, nb_users, nb_items,
+                 mf_dim, mf_reg,
+                 mlp_layer_sizes, mlp_layer_regs, split_final=False):
+        if len(mlp_layer_sizes) != len(mlp_layer_regs):
+            raise RuntimeError('u dummy, layer_sizes != layer_regs!')
+        if mlp_layer_sizes[0] % 2 != 0:
+            raise RuntimeError('u dummy, mlp_layer_sizes[0] % 2 != 0')
+        super(NeuMF, self).__init__()
+
+        self.mf_dim = mf_dim
+        self.mlp_layer_sizes = mlp_layer_sizes
+
+        nb_mlp_layers = len(mlp_layer_sizes)
+
+        # TODO: regularization?
+        self.mf_user_embed = nn.Embedding(nb_users, mf_dim)
+        self.mf_item_embed = nn.Embedding(nb_items, mf_dim)
+        self.mlp_user_embed = nn.Embedding(nb_users, mlp_layer_sizes[0] // 2)
+        self.mlp_item_embed = nn.Embedding(nb_items, mlp_layer_sizes[0] // 2)
+
+        self.mf_mult = distiller.modules.EltwiseMult()
+        self.mlp_concat = distiller.modules.Concat(dim=1)
+
+        self.mlp = nn.ModuleList()
+        self.mlp_relu = nn.ModuleList()
+        for i in range(1, nb_mlp_layers):
+            self.mlp.extend([nn.Linear(mlp_layer_sizes[i - 1], mlp_layer_sizes[i])])  # noqa: E501
+            self.mlp_relu.extend([nn.ReLU()])
+
+        self.split_final = split_final
+        if not split_final:
+            self.final_concat = distiller.modules.Concat(dim=1)
+            self.final = nn.Linear(mlp_layer_sizes[-1] + mf_dim, 1)
+        else:
+            self.final_mlp = nn.Linear(mlp_layer_sizes[-1], 1)
+            self.final_mf = nn.Linear(mf_dim, 1)
+            self.final_add = distiller.modules.EltwiseAdd()
+
+        self.sigmoid = nn.Sigmoid()
+
+        self.mf_user_embed.weight.data.normal_(0., 0.01)
+        self.mf_item_embed.weight.data.normal_(0., 0.01)
+        self.mlp_user_embed.weight.data.normal_(0., 0.01)
+        self.mlp_item_embed.weight.data.normal_(0., 0.01)
+
+        def golorot_uniform(layer):
+            fan_in, fan_out = layer.in_features, layer.out_features
+            limit = np.sqrt(6. / (fan_in + fan_out))
+            layer.weight.data.uniform_(-limit, limit)
+
+        def lecunn_uniform(layer):
+            fan_in, fan_out = layer.in_features, layer.out_features  # noqa: F841, E501
+            limit = np.sqrt(3. / fan_in)
+            layer.weight.data.uniform_(-limit, limit)
+
+        for layer in self.mlp:
+            if type(layer) != nn.Linear:
+                continue
+            golorot_uniform(layer)
+        if not split_final:
+            lecunn_uniform(self.final)
+        else:
+            lecunn_uniform(self.final_mlp)
+            lecunn_uniform(self.final_mf)
+
+    def load_state_dict(self, state_dict, strict=True):
+        if 'final.weight' in state_dict and self.split_final:
+            # Loading no-split checkpoint into split model
+
+            # MF weights come first, then MLP
+            final_weight = state_dict.pop('final.weight')
+            state_dict['final_mf.weight'] = final_weight[0][:self.mf_dim].unsqueeze(0)
+            state_dict['final_mlp.weight'] = final_weight[0][self.mf_dim:].unsqueeze(0)
+
+            # Split bias 50-50
+            final_bias = state_dict.pop('final.bias')
+            state_dict['final_mf.bias'] = final_bias * 0.5
+            state_dict['final_mlp.bias'] = final_bias * 0.5
+        elif 'final_mf.weight' in state_dict and not self.split_final:
+            # Loading split checkpoint into no-split model
+            state_dict['final.weight'] = torch.cat((state_dict.pop('final_mf.weight')[0],
+                                                    state_dict.pop('final_mlp.weight')[0])).unsqueeze(0)
+            state_dict['final.bias'] = state_dict.pop('final_mf.bias') + state_dict.pop('final_mlp.bias')
+
+        super(NeuMF, self).load_state_dict(state_dict, strict)
+
+    def forward(self, user, item, sigmoid):
+        xmfu = self.mf_user_embed(user)
+        xmfi = self.mf_item_embed(item)
+        xmf = self.mf_mult(xmfu, xmfi)
+
+        xmlpu = self.mlp_user_embed(user)
+        xmlpi = self.mlp_item_embed(item)
+        xmlp = self.mlp_concat(xmlpu, xmlpi)
+        for i, (layer, act) in enumerate(zip(self.mlp, self.mlp_relu)):
+            xmlp = layer(xmlp)
+            xmlp = act(xmlp)
+
+        if not self.split_final:
+            x = self.final_concat(xmf, xmlp)
+            x = self.final(x)
+        else:
+            xmf = self.final_mf(xmf)
+            xmlp = self.final_mlp(xmlp)
+            x = self.final_add(xmf, xmlp)
+        if sigmoid:
+            x = self.sigmoid(x)
+        return x
diff --git a/examples/ncf/quantization_stats_no_split.yaml b/examples/ncf/quantization_stats_no_split.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f32232c563dce72256d95b8a9e36363d5fe1abf4
--- /dev/null
+++ b/examples/ncf/quantization_stats_no_split.yaml
@@ -0,0 +1,294 @@
+mf_user_embed:
+  inputs:
+    0:
+      min: 0
+      max: 13848
+      avg_min: 6924.0
+      avg_max: 6924.0
+      mean: 6924.0
+      std: 3997.8620729189115
+      shape: (1000)
+  output:
+    min: -2.3631081581115723
+    max: 2.634204864501953
+    avg_min: -1.071664170499179
+    avg_max: 1.10968593620068
+    mean: 0.004670127670715561
+    std: 0.4853287812380178
+    shape: (1000, 64)
+mf_item_embed:
+  inputs:
+    0:
+      min: 0
+      max: 26743
+      avg_min: 32.519748718318944
+      avg_max: 26716.69644017626
+      mean: 13429.908568849716
+      std: 7695.305228284578
+      shape: (1000)
+  output:
+    min: -3.886291742324829
+    max: 3.3996453285217285
+    avg_min: -1.8537866596848787
+    avg_max: 1.733846694667886
+    mean: -0.1020138903885282
+    std: 0.9169524390367664
+    shape: (1000, 64)
+mlp_user_embed:
+  inputs:
+    0:
+      min: 0
+      max: 13848
+      avg_min: 6924.0
+      avg_max: 6924.0
+      mean: 6924.0
+      std: 3997.8620729189115
+      shape: (1000)
+  output:
+    min: -2.3077099323272705
+    max: 2.019761323928833
+    avg_min: -0.8393908124364596
+    avg_max: 0.8563097013907461
+    mean: -0.0058890159863465566
+    std: 0.33145011084640036
+    shape: (1000, 128)
+mlp_item_embed:
+  inputs:
+    0:
+      min: 0
+      max: 26743
+      avg_min: 32.519748718318944
+      avg_max: 26716.69644017626
+      mean: 13429.908568849716
+      std: 7695.305228284578
+      shape: (1000)
+  output:
+    min: -4.184338569641113
+    max: 3.6927380561828613
+    avg_min: -1.6649888157199006
+    avg_max: 1.5336890253437756
+    mean: 0.03420270613169851
+    std: 0.6040079522209375
+    shape: (1000, 128)
+mf_mult:
+  inputs:
+    0:
+      min: -2.3631081581115723
+      max: 2.634204864501953
+      avg_min: -1.071664170499179
+      avg_max: 1.10968593620068
+      mean: 0.004670127670715561
+      std: 0.4853287812380178
+      shape: (1000, 64)
+    1:
+      min: -3.886291742324829
+      max: 3.3996453285217285
+      avg_min: -1.8537866596848787
+      avg_max: 1.733846694667886
+      mean: -0.1020138903885282
+      std: 0.9169524390367664
+      shape: (1000, 64)
+  output:
+    min: -6.388758659362793
+    max: 7.461198329925537
+    avg_min: -1.4091147140708975
+    avg_max: 1.190038402591932
+    mean: -0.03373257358976464
+    std: 0.47984411373143276
+    shape: (1000, 64)
+mlp_concat:
+  inputs:
+    0:
+      min: -2.3077099323272705
+      max: 2.019761323928833
+      avg_min: -0.8393908124364596
+      avg_max: 0.8563097013907461
+      mean: -0.0058890159863465566
+      std: 0.33145011084640036
+      shape: (1000, 128)
+    1:
+      min: -4.184338569641113
+      max: 3.6927380561828613
+      avg_min: -1.6649888157199006
+      avg_max: 1.5336890253437756
+      mean: 0.03420270613169851
+      std: 0.6040079522209375
+      shape: (1000, 128)
+  output:
+    min: -4.184338569641113
+    max: 3.6927380561828613
+    avg_min: -1.728540082865832
+    avg_max: 1.5992073092841723
+    mean: 0.014156845084580252
+    std: 0.4875919206474072
+    shape: (1000, 256)
+mlp.0:
+  inputs:
+    0:
+      min: -4.184338569641113
+      max: 3.6927380561828613
+      avg_min: -1.728540082865832
+      avg_max: 1.5992073092841723
+      mean: 0.014156845084580252
+      std: 0.4875919206474072
+      shape: (1000, 256)
+  output:
+    min: -25.551782608032227
+    max: 30.255319595336914
+    avg_min: -10.410937029339797
+    avg_max: 12.199230058019891
+    mean: -1.3596681231175995
+    std: 3.435064250450655
+    shape: (1000, 256)
+mlp.1:
+  inputs:
+    0:
+      min: 0.0
+      max: 30.255319595336914
+      avg_min: 0.0
+      avg_max: 12.199230058019891
+      mean: 0.6925220241296313
+      std: 1.6946167814794522
+      shape: (1000, 256)
+  output:
+    min: -231.78152465820312
+    max: 82.65782165527344
+    avg_min: -60.22239387931445
+    avg_max: 21.538379845476946
+    mean: -11.969065733546032
+    std: 16.571529820325505
+    shape: (1000, 128)
+mlp.2:
+  inputs:
+    0:
+      min: 0.0
+      max: 82.65782165527344
+      avg_min: 0.0
+      avg_max: 21.538379845476946
+      mean: 1.362308199108358
+      std: 3.989317218613674
+      shape: (1000, 128)
+  output:
+    min: -235.94625854492188
+    max: 203.7071990966797
+    avg_min: -54.71186078950498
+    avg_max: 29.957007628019554
+    mean: -7.084710118819055
+    std: 21.72214522135367
+    shape: (1000, 64)
+mlp_relu.0:
+  inputs:
+    0:
+      min: -25.551782608032227
+      max: 30.255319595336914
+      avg_min: -10.410937029339797
+      avg_max: 12.199230058019891
+      mean: -1.3596681231175995
+      std: 3.435064250450655
+      shape: (1000, 256)
+  output:
+    min: 0.0
+    max: 30.255319595336914
+    avg_min: 0.0
+    avg_max: 12.199230058019891
+    mean: 0.6925220241296313
+    std: 1.6946167814794522
+    shape: (1000, 256)
+mlp_relu.1:
+  inputs:
+    0:
+      min: -231.78152465820312
+      max: 82.65782165527344
+      avg_min: -60.22239387931445
+      avg_max: 21.538379845476946
+      mean: -11.969065733546032
+      std: 16.571529820325505
+      shape: (1000, 128)
+  output:
+    min: 0.0
+    max: 82.65782165527344
+    avg_min: 0.0
+    avg_max: 21.538379845476946
+    mean: 1.362308199108358
+    std: 3.989317218613674
+    shape: (1000, 128)
+mlp_relu.2:
+  inputs:
+    0:
+      min: -235.94625854492188
+      max: 203.7071990966797
+      avg_min: -54.71186078950498
+      avg_max: 29.957007628019554
+      mean: -7.084710118819055
+      std: 21.72214522135367
+      shape: (1000, 64)
+  output:
+    min: 0.0
+    max: 203.7071990966797
+    avg_min: 0.0
+    avg_max: 29.95700772112138
+    mean: 4.937804440873921
+    std: 11.42688696572864
+    shape: (1000, 64)
+final_concat:
+  inputs:
+    0:
+      min: -6.388758659362793
+      max: 7.461198329925537
+      avg_min: -1.4091147140708975
+      avg_max: 1.190038402591932
+      mean: -0.03373257358976464
+      std: 0.47984411373143276
+      shape: (1000, 64)
+    1:
+      min: 0.0
+      max: 203.7071990966797
+      avg_min: 0.0
+      avg_max: 29.95700772112138
+      mean: 4.937804440873921
+      std: 11.42688696572864
+      shape: (1000, 64)
+  output:
+    min: -6.388758659362793
+    max: 203.7071990966797
+    avg_min: -1.4091147140708975
+    avg_max: 29.95700872795735
+    mean: 2.452035934175203
+    std: 8.46057860792389
+    shape: (1000, 128)
+final:
+  inputs:
+    0:
+      min: -6.388758659362793
+      max: 203.7071990966797
+      avg_min: -1.4091147140708975
+      avg_max: 29.95700872795735
+      mean: 2.452035934175203
+      std: 8.46057860792389
+      shape: (1000, 128)
+  output:
+    min: -264.23663330078125
+    max: 10.719743728637695
+    avg_min: -64.09727749207161
+    avg_max: 4.514405594118789
+    mean: -27.331936557333087
+    std: 29.674832823876194
+    shape: (1000, 1)
+sigmoid:
+  inputs:
+    0:
+      min: -264.23663330078125
+      max: 10.719743728637695
+      avg_min: -64.09727749207161
+      avg_max: 4.514405594118789
+      mean: -27.331936557333087
+      std: 29.674832823876194
+      shape: (1000, 1)
+  output:
+    min: 0.0
+    max: 0.9999779462814331
+    avg_min: 7.22119551814259e-08
+    avg_max: 0.9796236092019589
+    mean: 0.025780337072657727
+    std: 0.11967490732565136
+    shape: (1000, 1)
diff --git a/examples/ncf/quantization_stats_split.yaml b/examples/ncf/quantization_stats_split.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..afbc00409ba2d336701d5a21095b67c68400db21
--- /dev/null
+++ b/examples/ncf/quantization_stats_split.yaml
@@ -0,0 +1,312 @@
+mf_user_embed:
+  inputs:
+    0:
+      min: 0
+      max: 13848
+      avg_min: 6924.0
+      avg_max: 6924.0
+      mean: 6924.0
+      std: 3997.8620729189115
+      shape: (1000)
+  output:
+    min: -2.3631081581115723
+    max: 2.634204864501953
+    avg_min: -1.071664170499179
+    avg_max: 1.10968593620068
+    mean: 0.004670127670715561
+    std: 0.4853287812380178
+    shape: (1000, 64)
+mf_item_embed:
+  inputs:
+    0:
+      min: 0
+      max: 26743
+      avg_min: 32.519748718318944
+      avg_max: 26716.69644017626
+      mean: 13429.908568849716
+      std: 7695.305228284578
+      shape: (1000)
+  output:
+    min: -3.886291742324829
+    max: 3.3996453285217285
+    avg_min: -1.8537866596848787
+    avg_max: 1.733846694667886
+    mean: -0.1020138903885282
+    std: 0.9169524390367664
+    shape: (1000, 64)
+mlp_user_embed:
+  inputs:
+    0:
+      min: 0
+      max: 13848
+      avg_min: 6924.0
+      avg_max: 6924.0
+      mean: 6924.0
+      std: 3997.8620729189115
+      shape: (1000)
+  output:
+    min: -2.3077099323272705
+    max: 2.019761323928833
+    avg_min: -0.8393908124364596
+    avg_max: 0.8563097013907461
+    mean: -0.0058890159863465566
+    std: 0.33145011084640036
+    shape: (1000, 128)
+mlp_item_embed:
+  inputs:
+    0:
+      min: 0
+      max: 26743
+      avg_min: 32.519748718318944
+      avg_max: 26716.69644017626
+      mean: 13429.908568849716
+      std: 7695.305228284578
+      shape: (1000)
+  output:
+    min: -4.184338569641113
+    max: 3.6927380561828613
+    avg_min: -1.6649888157199006
+    avg_max: 1.5336890253437756
+    mean: 0.03420270613169851
+    std: 0.6040079522209375
+    shape: (1000, 128)
+mf_mult:
+  inputs:
+    0:
+      min: -2.3631081581115723
+      max: 2.634204864501953
+      avg_min: -1.071664170499179
+      avg_max: 1.10968593620068
+      mean: 0.004670127670715561
+      std: 0.4853287812380178
+      shape: (1000, 64)
+    1:
+      min: -3.886291742324829
+      max: 3.3996453285217285
+      avg_min: -1.8537866596848787
+      avg_max: 1.733846694667886
+      mean: -0.1020138903885282
+      std: 0.9169524390367664
+      shape: (1000, 64)
+  output:
+    min: -6.388758659362793
+    max: 7.461198329925537
+    avg_min: -1.4091147140708975
+    avg_max: 1.190038402591932
+    mean: -0.03373257358976464
+    std: 0.47984411373143276
+    shape: (1000, 64)
+mlp_concat:
+  inputs:
+    0:
+      min: -2.3077099323272705
+      max: 2.019761323928833
+      avg_min: -0.8393908124364596
+      avg_max: 0.8563097013907461
+      mean: -0.0058890159863465566
+      std: 0.33145011084640036
+      shape: (1000, 128)
+    1:
+      min: -4.184338569641113
+      max: 3.6927380561828613
+      avg_min: -1.6649888157199006
+      avg_max: 1.5336890253437756
+      mean: 0.03420270613169851
+      std: 0.6040079522209375
+      shape: (1000, 128)
+  output:
+    min: -4.184338569641113
+    max: 3.6927380561828613
+    avg_min: -1.728540082865832
+    avg_max: 1.5992073092841723
+    mean: 0.014156845084580252
+    std: 0.4875919206474072
+    shape: (1000, 256)
+mlp.0:
+  inputs:
+    0:
+      min: -4.184338569641113
+      max: 3.6927380561828613
+      avg_min: -1.728540082865832
+      avg_max: 1.5992073092841723
+      mean: 0.014156845084580252
+      std: 0.4875919206474072
+      shape: (1000, 256)
+  output:
+    min: -25.551782608032227
+    max: 30.255319595336914
+    avg_min: -10.410937029339797
+    avg_max: 12.199230058019891
+    mean: -1.3596681231175995
+    std: 3.435064250450655
+    shape: (1000, 256)
+mlp.1:
+  inputs:
+    0:
+      min: 0.0
+      max: 30.255319595336914
+      avg_min: 0.0
+      avg_max: 12.199230058019891
+      mean: 0.6925220241296313
+      std: 1.6946167814794522
+      shape: (1000, 256)
+  output:
+    min: -231.78152465820312
+    max: 82.65782165527344
+    avg_min: -60.22239387931445
+    avg_max: 21.538379845476946
+    mean: -11.969065733546032
+    std: 16.571529820325505
+    shape: (1000, 128)
+mlp.2:
+  inputs:
+    0:
+      min: 0.0
+      max: 82.65782165527344
+      avg_min: 0.0
+      avg_max: 21.538379845476946
+      mean: 1.362308199108358
+      std: 3.989317218613674
+      shape: (1000, 128)
+  output:
+    min: -235.94625854492188
+    max: 203.7071990966797
+    avg_min: -54.71186078950498
+    avg_max: 29.957007628019554
+    mean: -7.084710118819055
+    std: 21.72214522135367
+    shape: (1000, 64)
+mlp_relu.0:
+  inputs:
+    0:
+      min: -25.551782608032227
+      max: 30.255319595336914
+      avg_min: -10.410937029339797
+      avg_max: 12.199230058019891
+      mean: -1.3596681231175995
+      std: 3.435064250450655
+      shape: (1000, 256)
+  output:
+    min: 0.0
+    max: 30.255319595336914
+    avg_min: 0.0
+    avg_max: 12.199230058019891
+    mean: 0.6925220241296313
+    std: 1.6946167814794522
+    shape: (1000, 256)
+mlp_relu.1:
+  inputs:
+    0:
+      min: -231.78152465820312
+      max: 82.65782165527344
+      avg_min: -60.22239387931445
+      avg_max: 21.538379845476946
+      mean: -11.969065733546032
+      std: 16.571529820325505
+      shape: (1000, 128)
+  output:
+    min: 0.0
+    max: 82.65782165527344
+    avg_min: 0.0
+    avg_max: 21.538379845476946
+    mean: 1.362308199108358
+    std: 3.989317218613674
+    shape: (1000, 128)
+mlp_relu.2:
+  inputs:
+    0:
+      min: -235.94625854492188
+      max: 203.7071990966797
+      avg_min: -54.71186078950498
+      avg_max: 29.957007628019554
+      mean: -7.084710118819055
+      std: 21.72214522135367
+      shape: (1000, 64)
+  output:
+    min: 0.0
+    max: 203.7071990966797
+    avg_min: 0.0
+    avg_max: 29.95700772112138
+    mean: 4.937804440873921
+    std: 11.42688696572864
+    shape: (1000, 64)
+final_mlp:
+  inputs:
+    0:
+      min: 0.0
+      max: 203.7071990966797
+      avg_min: 0.0
+      avg_max: 29.95700772112138
+      mean: 4.937804440873921
+      std: 11.42688696572864
+      shape: (1000, 64)
+  output:
+    min: -283.5218200683594
+    max: 40.86720275878906
+    avg_min: -60.916563263919905
+    avg_max: 6.9745166829700596
+    mean: -22.90735553673859
+    std: 35.900340843176544
+    shape: (1000, 1)
+final_mf:
+  inputs:
+    0:
+      min: -6.388758659362793
+      max: 7.461198329925537
+      avg_min: -1.4091147140708975
+      avg_max: 1.190038402591932
+      mean: -0.03373257358976464
+      std: 0.47984411373143276
+      shape: (1000, 64)
+  output:
+    min: -54.07410430908203
+    max: 47.101890563964844
+    avg_min: -16.45866186288787
+    avg_max: 7.688797266098178
+    mean: -4.42458063103854
+    std: 9.54055961838881
+    shape: (1000, 1)
+final_add:
+  inputs:
+    0:
+      min: -54.07410430908203
+      max: 47.101890563964844
+      avg_min: -16.45866186288787
+      avg_max: 7.688797266098178
+      mean: -4.42458063103854
+      std: 9.54055961838881
+      shape: (1000, 1)
+    1:
+      min: -283.5218200683594
+      max: 40.86720275878906
+      avg_min: -60.916563263919905
+      avg_max: 6.9745166829700596
+      mean: -22.90735553673859
+      std: 35.900340843176544
+      shape: (1000, 1)
+  output:
+    min: -264.23663330078125
+    max: 10.719744682312012
+    avg_min: -64.09727644866952
+    avg_max: 4.514405736883009
+    mean: -27.33193622736199
+    std: 29.674832823876194
+    shape: (1000, 1)
+sigmoid:
+  inputs:
+    0:
+      min: -264.23663330078125
+      max: 10.719744682312012
+      avg_min: -64.09727644866952
+      avg_max: 4.514405736883009
+      mean: -27.33193622736199
+      std: 29.674832823876194
+      shape: (1000, 1)
+  output:
+    min: 0.0
+    max: 0.9999779462814331
+    avg_min: 7.221196356095192e-08
+    avg_max: 0.9796236107793396
+    mean: 0.025780336994990046
+    std: 0.11967490732565136
+    shape: (1000, 1)
diff --git a/examples/ncf/utils.py b/examples/ncf/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4395830c050b54fbfc4212dff4bbf5d3c755ed00
--- /dev/null
+++ b/examples/ncf/utils.py
@@ -0,0 +1,41 @@
+import os
+import json
+from functools import reduce
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def count_parameters(model):
+    c = map(lambda p: reduce(lambda x, y: x * y, p.size()), model.parameters())
+    return sum(c)
+
+
+def save_config(config, run_dir):
+    path = os.path.join(run_dir, "config_{}.json".format(config['timestamp']))
+    with open(path, 'w') as config_file:
+        json.dump(config, config_file)
+        config_file.write('\n')
+
+
+def save_result(result, path):
+    write_heading = not os.path.exists(path)
+    with open(path, mode='a') as out:
+        if write_heading:
+            out.write(",".join([str(k) for k, v in result.items()]) + '\n')
+        out.write(",".join([str(v) for k, v in result.items()]) + '\n')
diff --git a/examples/ncf/verify_dataset.sh b/examples/ncf/verify_dataset.sh
new file mode 100755
index 0000000000000000000000000000000000000000..208d7602a8fad8bf3151edf90e8ecd2641195e14
--- /dev/null
+++ b/examples/ncf/verify_dataset.sh
@@ -0,0 +1,44 @@
+function get_checker {
+    if [[ "$OSTYPE" == "darwin"* ]]; then
+        checkmd5=md5
+    else
+        checkmd5=md5sum
+    fi
+
+    echo $checkmd5
+}
+
+
+function verify_1m {
+    # From: curl -O http://files.grouplens.org/datasets/movielens/ml-1m.zip.md5
+    hash=<(echo "MD5 (ml-1m.zip) = c4d9eecfca2ab87c1945afe126590906")
+    local checkmd5=$(get_checker)
+    if diff <($checkmd5 ml-1m.zip) $hash &> /dev/null
+    then
+        echo "PASSED"
+    else
+        echo "FAILED"
+    fi
+}
+
+function verify_20m {
+    # From: curl -O http://files.grouplens.org/datasets/movielens/ml-20m.zip.md5
+    hash=<(echo "MD5 (ml-20m.zip) = cd245b17a1ae2cc31bb14903e1204af3")
+    local checkmd5=$(get_checker)
+
+    if diff <($checkmd5 ml-20m.zip) $hash &> /dev/null
+    then
+        echo "PASSED"
+    else
+        echo "FAILED"
+    fi
+
+}
+
+
+if [[ $1 == "ml-1m" ]]
+then
+    verify_1m
+else
+    verify_20m
+fi
diff --git a/jupyter/truncated_svd.ipynb b/jupyter/truncated_svd.ipynb
index 7331d3c54b228f57364c1d27a3c17fbf9b94ecc3..07f41d657c30f42480b7ef20c58de434aaa1ed71 100644
--- a/jupyter/truncated_svd.ipynb
+++ b/jupyter/truncated_svd.ipynb
@@ -40,12 +40,54 @@
     "    - Top1: 75.65  \n",
     "    - Top5: 92.75\n",
     "    \n",
-    "    Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000) "
+    "    Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000)\n",
+    "    \n",
+    "## Details\n",
+    "\n",
+    "[SVD (Singular Value Decomposition)](https://en.wikipedia.org/wiki/Singular_value_decomposition) is an exact factorization of a matrix, W (of shape m x n), to the form USV$^T$ (U is m x m, S is m x n, V$^T$ is n x n; V$^T$ is the transpose of V). Every matrix has an SVD.\n",
+    "\n",
+    "A Linear (fully-connected) layer performs: y = Wx + b   (or y = xW$^T$ + b)<br>\n",
+    "We can use SVD to refactor W to rewrite this as: y = (USV$^T$)x + b\n",
+    "\n",
+    "So far, we haven’t done any compression, so let’s get to it using Truncated SVD.<br> TSVD is a method to provide an approximated decomposition of W, in which S has a lower rank. We want to find an approximation of W that is “good enough” and also accelerates the computation of Wx.\n",
+    "\n",
+    "We choose some lower-rank, k, such that k<m (preferably k<<m).<br>\n",
+    "TSVD is straight-forward: keep the largest k singular values of S and discard the rest (truncate S).\n",
+    "\n",
+    "After TSVD we have:\n",
+    "U’ is m x t, S’ is k x k, V’$^T$ is k x n.<br>\n",
+    "y ~ (U’S’V’)x + b<br>\n",
+    "y ~ (U’(S’V’))x + b<br>\n",
+    "\n",
+    "We'll replace S’V’$^T$ with A, because we can pre-compute it once. A has shape k x n:<br>\n",
+    "y ~ (U’A)x + b<br>\n",
+    "y ~ U’(Ax) + b<br>\n",
+    "\n",
+    "Let’s ignore the bias and calculate the number of parameters and FLOPs (floating point operations) for the original y:\n",
+    "\n",
+    " - m * n weights coefficients<br>\n",
+    " - m * n FLOPs (for batch size = 1)<br>\n",
+    "\n",
+    "After TSVD we have:\n",
+    "\n",
+    " - mk + kn = k*(m+n) weights coefficients<br>\n",
+    " - kn + mk = k*(m+n) FLOPs (for batch size = 1)<br>\n",
+    "\n",
+    "To actually compress the weights after TSVD, we want: m * n > k*(m+n)<br>\n",
+    "Let’s rewrite k in terms of m: k = tm<br>\n",
+    "m * n > tm*(m+n)<br>\n",
+    "n > t*(m+n)<br>\n",
+    "n / (m+n) >= t<br>\n",
+    "\n",
+    "This is the math, but for an actual performance increase, we should strive for m * n >> k*(m+n)\n",
+    "\n",
+    "In the example notebook: m = 1000; n=2048<br>\n",
+    "So when t=2048/(1000+2048) (that is, k=2048/3048*1000=672), we have equilibrium. When 0.672>t (i.e. k is smaller than 672), the sum of the size of the weights of A and U’ is smaller than the size of W.\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -80,7 +122,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -105,14 +147,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
     "BATCH_SIZE = 4\n",
     "\n",
     "# Data loader\n",
-    "test_loader = imagenet_load_data(\"../../data.imagenet/\", \n",
+    "test_loader = imagenet_load_data(\"/datasets/imagenet/\", \n",
     "                                 batch_size=BATCH_SIZE, \n",
     "                                 num_workers=2)\n",
     "    \n",
@@ -136,11 +178,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 13,
    "metadata": {
     "scrolled": false
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "fc_layer(2048, 1000)\n",
+      "W = torch.Size([1000, 2048])\n",
+      "U = torch.Size([1000, 400])\n",
+      "SV = torch.Size([400, 2048])\n"
+     ]
+    }
+   ],
    "source": [
     "# Load the various models\n",
     "resnet50 = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=False)\n",
@@ -170,11 +223,11 @@
     "\n",
     "\n",
     "class TruncatedSVD(nn.Module):\n",
-    "    def __init__(self, replaced_gemm, gemm_weights):\n",
-    "        super(TruncatedSVD,self).__init__()\n",
+    "    def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):\n",
+    "        super().__init__()\n",
     "        self.replaced_gemm = replaced_gemm\n",
     "        print(\"W = {}\".format(gemm_weights.shape))\n",
-    "        self.U, self.SV = truncated_svd(gemm_weights.data, int(0.4 * gemm_weights.size(0)))\n",
+    "        self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))\n",
     "        print(\"U = {}\".format(self.U.shape))\n",
     "        \n",
     "        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()\n",
@@ -182,20 +235,21 @@
     "        \n",
     "        print(\"SV = {}\".format(self.SV.shape))\n",
     "        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()\n",
-    "        self.fc_sv.weight.data = self.SV#.t()\n",
-    "        \n",
+    "        self.fc_sv.weight.data = self.SV#.t()        \n",
     "\n",
     "    def forward(self, x):\n",
     "        x = self.fc_sv.forward(x)\n",
     "        x = self.fc_u.forward(x)\n",
     "        return x\n",
     "\n",
+    "    \n",
     "def replace(model):\n",
     "    fc_weights = model.state_dict()['fc.weight']\n",
     "    fc_layer = model.fc\n",
     "    print(\"fc_layer({}, {})\".format(fc_layer.in_features, fc_layer.out_features))\n",
-    "    model.fc = TruncatedSVD(fc_layer, fc_weights)\n",
+    "    model.fc = TruncatedSVD(fc_layer, fc_weights, 0.4)\n",
     "\n",
+    "    \n",
     "from copy import deepcopy\n",
     "resnet50 = deepcopy(resnet50)\n",
     "replace(resnet50)"
@@ -203,18 +257,34 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 14,
    "metadata": {
     "scrolled": false
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "progress: 51200 images\n",
+      "progress: 102400 images\n",
+      "progress: 153600 images\n",
+      "progress: 204800 images\n",
+      "progress: 256000 images\n",
+      "progress: 307200 images\n",
+      "progress: 358400 images\n",
+      "Top1: 75.70   Top5: 92.76\n",
+      "Duration:  168.111492395401\n"
+     ]
+    }
+   ],
    "source": [
     "# Standard loop to test the accuracy of a model.\n",
     "\n",
     "import time\n",
     "import torchnet.meter as tnt\n",
     "t0 = time.time()\n",
-    "test_loader = imagenet_load_data(\"../../data.imagenet\", \n",
+    "test_loader = imagenet_load_data(\"/datasets/imagenet\", \n",
     "                                 batch_size=64, \n",
     "                                 num_workers=4,\n",
     "                                 shuffle=False)\n",
@@ -234,6 +304,13 @@
     "t2 = time.time()\n",
     "print(\"Duration: \", t2-t0)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
diff --git a/licenses/py_faster_rcnn-license.txt b/licenses/py_faster_rcnn-license.txt
new file mode 100755
index 0000000000000000000000000000000000000000..1ab42b27a3ac66e841a94d6f568be493efcde274
--- /dev/null
+++ b/licenses/py_faster_rcnn-license.txt
@@ -0,0 +1,81 @@
+Faster R-CNN
+
+The MIT License (MIT)
+
+Copyright (c) 2015 Microsoft Corporation
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+************************************************************************
+
+THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
+
+This project, Faster R-CNN, incorporates material from the project(s)
+listed below (collectively, "Third Party Code").  Microsoft is not the
+original author of the Third Party Code.  The original copyright notice
+and license under which Microsoft received such Third Party Code are set
+out below. This Third Party Code is licensed to you under their original
+license terms set forth below.  Microsoft reserves all other rights not
+expressly granted, whether by implication, estoppel or otherwise.
+
+1.	Caffe, (https://github.com/BVLC/caffe/)
+
+COPYRIGHT
+
+All contributions by the University of California:
+Copyright (c) 2014, 2015, The Regents of the University of California (Regents)
+All rights reserved.
+
+All other contributions:
+Copyright (c) 2014, 2015, the respective contributors
+All rights reserved.
+
+Caffe uses a shared copyright model: each contributor holds copyright
+over their contributions to Caffe. The project versioning records all
+such contribution and copyright details. If a contributor wants to
+further mark their specific copyright on a particular contribution,
+they should indicate their copyright solely in the commit message of
+the change when it is committed.
+
+The BSD 2-Clause License
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+
+1. Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
diff --git a/requirements.txt b/requirements.txt
index 06e56c4711c81dce3432bb7ce938c2baea71fe3e..c69c7065753c638ef5c5a77c9d7077fbc6855c43 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,3 +20,4 @@ xlsxwriter>=1.1.1
 pretrainedmodels==0.7.4
 scikit-learn==0.21.2
 gym==0.12.5
+tqdm==4.33.0