diff --git a/.gitignore b/.gitignore index a1670e9fddf038bfdb2c6a14f0de7843322bba9d..110302734953a65f823713dda57f4f2f53b23a9c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,8 @@ __pycache__/ .pytest_cache .cache pytest_collaterals/ +examples/ncf/run/ +examples/ncf/ml-20m* # GNMT sample examples/GNMT/data diff --git a/distiller/modules/tsvd.py b/distiller/modules/tsvd.py new file mode 100755 index 0000000000000000000000000000000000000000..aea20bc2e3726bf2dd8104926cd42a873f9bc56f --- /dev/null +++ b/distiller/modules/tsvd.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Truncated-SVD module. + +For an example of how truncated-SVD can be used, see this Jupyter notebook: +https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb + +""" + +def truncated_svd(W, l): + """Compress the weight matrix W of an inner product (fully connected) layer using truncated SVD. + + For the original implementation (MIT license), see Faster-RCNN: + https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py + We replaced numpy operations with pytorch operations (so that we can leverage the GPU). 
+ + Arguments: + W: N x M weights matrix + l: number of singular values to retain + Returns: + Ul, L: matrices such that W \approx Ul*L + """ + + U, s, V = torch.svd(W, some=True) + + Ul = U[:, :l] + sl = s[:l] + V = V.t() + Vl = V[:l, :] + + SV = torch.mm(torch.diag(sl), Vl) + return Ul, SV + + +class TruncatedSVD(nn.Module): + def __init__(self, replaced_gemm, gemm_weights, preserve_ratio): + super().__init__() + self.replaced_gemm = replaced_gemm + print("W = {}".format(gemm_weights.shape)) + self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0))) + print("U = {}".format(self.U.shape)) + + self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda() + self.fc_u.weight.data = self.U + + print("SV = {}".format(self.SV.shape)) + self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda() + self.fc_sv.weight.data = self.SV#.t() + + def forward(self, x): + x = self.fc_sv.forward(x) + x = self.fc_u.forward(x) + return x diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 56454e7911d03cf60189c7695160401eaef20ca9..96b8dceb70956320c6fc67380ec85ad00cedec21 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -446,8 +446,7 @@ class ActivationRankedFilterPruner(_RankedStructureParameterPruner): if fraction_to_prune == 0: return binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, - zeros_mask_dict, model, binary_map, - self.rounding_fn) + zeros_mask_dict, model, binary_map) return binary_map def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None): diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py index 06d5d9a31233798b1c4b6112a6d38cc73e0c18e8..e24c97686556d708d111f4b3bc7023bcff5985e9 100644 --- a/distiller/quantization/__init__.py +++ b/distiller/quantization/__init__.py @@ -16,7 +16,7 @@ from 
.quantizer import Quantizer from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, PostTrainLinearQuantizer, \ - LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args,\ + LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \ RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index 7bbedc42644e097c18f51f6458df7bddbab63c53..e632cef7259df7dc66357f8dceae952f691ad44e 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -1355,3 +1355,21 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer): per_channel=perch) m.register_buffer(ptq.q_attr_name + '_scale', torch.ones_like(scale)) m.register_buffer(ptq.q_attr_name + '_zero_point', torch.zeros_like(zero_point)) + + +class NCFQuantAwareTrainQuantizer(QuantAwareTrainRangeLinearQuantizer): + def __init__(self, model, optimizer=None, bits_activations=32, bits_weights=32, bits_bias=32, + overrides=None, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, per_channel_wts=False): + super(NCFQuantAwareTrainQuantizer, self).__init__(model, optimizer=optimizer, + bits_activations=bits_activations, + bits_weights=bits_weights, + bits_bias=bits_bias, + overrides=overrides, + mode=mode, ema_decay=ema_decay, + per_channel_wts=per_channel_wts, + quantize_inputs=False) + + self.replacement_factory[distiller.modules.EltwiseMult] = self.activation_replace_fn + self.replacement_factory[distiller.modules.Concat] = self.activation_replace_fn + self.replacement_factory[nn.Linear] = self.activation_replace_fn + # self.replacement_factory[nn.Sigmoid] = self.activation_replace_fn diff --git a/examples/ncf/MLPERF_LICENSE.md 
b/examples/ncf/MLPERF_LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..89ba5be4833f1b8418c869f63ab8ee1f6e881531 --- /dev/null +++ b/examples/ncf/MLPERF_LICENSE.md @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018 The MLPerf Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/examples/ncf/README.md b/examples/ncf/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eac9b0c5ed7963c996019a7899c9eb0d7c57dfbc --- /dev/null +++ b/examples/ncf/README.md @@ -0,0 +1,172 @@ +# NCF - Neural Collaborative Filtering + +The NCF implementation provided here is based on the implementation found in the MLPerf Training GitHub repository. +This sample is not based on the latest implementation in MLPerf; it is based on an earlier revision which uses the ml-20m dataset. The latest code uses a much larger dataset. We plan to move to the latest version in the near future. +You can find the revision this sample is based on [here](https://github.com/mlperf/training/tree/fe17e837ed12974d15c86d5173fe8f2c188434d5/recommendation/pytorch).
+ +We've made several modifications to the code: +* Removed all MLPerf-specific code, including logging +* In `ncf.py`: + * Added calls to Distiller compression APIs + * Added progress indication in training and evaluation flows +* In `neumf.py`: + * Added option to split the final FC layer (the `split_final` parameter). See [below](#side-note-splitting-the-final-fc-layer). + * Replaced all functional calls with modules so they can be detected by Distiller, as per this [guide](https://nervanasystems.github.io/distiller/prepare_model_quant.html) in the Distiller docs. +* In `dataset.py`: + * Sped up data loading - On the first run, data is loaded from CSVs and then pickled. On subsequent runs the pickle is loaded. This is much faster than the original implementation, but still very slow. + * Added progress indication during data load process + +The sample command lines provided [below](#running-the-sample) focus on **post-training quantization**. We did integrate the capability to run quantization-aware training into `ncf.py`. We'll add examples for this at a later time. + +## Problem + +This task benchmarks recommendation with implicit feedback on the [MovieLens 20 Million (ml-20m) dataset](https://grouplens.org/datasets/movielens/20m/) with a [Neural Collaborative Filtering](http://dl.acm.org/citation.cfm?id=3052569) model. +The model trains on binary information about whether or not a user interacted with a specific item. + +## Setup + +### Steps to configure machine + +* Install `unzip` and `curl` + + ```bash + sudo apt-get install unzip curl + ``` + +* Make sure the latest Distiller requirements are installed + + ```bash + # Relative to this sample directory + cd <distiller-repo-root> + pip install -e .
+ ``` + +* Download and verify data + + ```bash + cd <distiller-repo-root>/examples/ncf + # Creates ml-20m.zip + source ../download_dataset.sh + # Confirms the MD5 checksum of ml-20m.zip + source ../verify_dataset.sh + ``` + +## Running the Sample + +### Train a Base FP32 Model + +We train a model with the following parameters: + +* MLP Side + * Embedding size per user / item: 128 + * FC layer sizes: 256x256 --> 256x128 --> 128x64 +* MF (matrix factorization) Side + * Embedding size per user / item: 64 +* Therefore, the final FC layer size is: 128x1 + +Adam optimizer is used, with an initial learning rate of 0.0005. Batch size is 2048. Convergence is obtained after 7 epochs. + +```bash +python ncf.py ml-20m -l 0.0005 -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --processes 10 -o run/neumf/base_fp32 +... +Epoch 0 Loss 0.1179 (0.1469): 100%|█████████████████████████████| 48491/48491 [07:04<00:00, 114.23it/s] +Epoch 0 evaluation +Epoch 0: HR@10 = 0.5738, NDCG@10 = 0.3367, AvgTrainLoss = 0.1469, train_time = 424.52, val_time = 47.04 +... +Epoch 6 Loss 0.0914 (0.0943): 100%|█████████████████████████████| 48491/48491 [06:47<00:00, 118.90it/s] +Epoch 6 evaluation +Epoch 6: HR@10 = 0.6355, NDCG@10 = 0.3820, AvgTrainLoss = 0.0943, train_time = 407.84, val_time = 62.99 +``` + +The hit-rate of the base model is 63.55%. + +### Side-Note: Splitting the Final FC Layer + +As mentioned above, we added an option to split the final FC layer of the model (the `split_final` parameter in `NeuMF.__init__`). + +The reasoning behind this is that the input to the final FC layer in NCF is a concatenation of the outputs of the MLP and MF "branches". These outputs have very different dynamic ranges. +In the model we just trained, the MLP branch output range is [0 .. 203] while the MF branch output range is [-6.3 .. 7.4]. When doing quantized concatenation, we have to accommodate the larger range, which leads to a large quantization error for the data that came from the MF branch.
When quantizing to 8-bits, the MF branch will cover only 10 bins out of the 256 bins, which means just over 3-bits. +The mitigation we use is to split the final FC layer as follows: + +``` + Before Split: After Split: + ------------- ------------ + MF_OUT MLP_OUT MF_OUT MLP_OUT + \ / | | + \ / ---> MF_FC MLP_FC + CONCAT \ / + | \ / + FINAL_FC \ / + ADD +``` +After splitting, the two inputs to the add operation have ranges [-283 .. 40] from the MLP side and [-54 .. 47] from the MF side. While the problem isn't completely solved, it's much better than before. Now the MF covers 126 bins, which is almost 7-bits. + +Note that in FP32 the 2 modes are functionally identical. The split final option is for evaluation only, and we take care to convert the model trained without splitting into a split model when loading the checkpoint. + +### Collect Quantization Stats for Post-Training Quantization + +We generated stats for both the non-split and split case. These are the `quantization_stats_no_split.yaml` and `quantization_stats_split.yaml` files in the example folder. + +For reference, the command lines used to generate these are: + +```bash +python ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --qe-calibration 0.1 +python ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --qe-calibration 0.1 --split-final +``` +Note that `--qe-calibration 0.1` means that we use 10% of the test dataset for the stats collection. 
+ +### Post-Training Quantization Experiments + +We'll use the following settings for quantization: + +* 8-bits for weights and activations: `--qeba 8 --qebw 8` +* Asymmetric: `--qem asym_u` +* Per-channel: `--qepc` + +Let's see the difference splitting the final FC layer makes in terms of overall accuracy: + +```bash +python ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --evaluate --quantize-eval --qeba 8 --qebw 8 --qem asym_u --qepc --qe-stats-file quantization_stats_no_split.yaml +... +Initial HR@10 = 0.4954, NDCG@10 = 0.2802, val_time = 521.11 +``` + +```bash +python ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --evaluate --quantize-eval --qeba 8 --qebw 8 --qem asym_u --qepc --split-final --qe-stats-file quantization_stats_split.yaml +... +HR@10 = 0.6278, NDCG@10 = 0.3760, val_time = 601.87 +``` + +We can see that without splitting, we get ~14% degradation in hit-rate. With splitting we gain almost all of the accuracy back, with about 0.8% degradation. + +## Dataset / Environment + +### Publication / Attribution + +Harper, F. M. & Konstan, J. A. (2015), 'The MovieLens Datasets: History and Context', ACM Trans. Interact. Intell. Syst. 5(4), 19:1--19:19. + +### Data preprocessing + +1. Unzip +2. Remove users with fewer than 20 reviews +3. Create training and test data separation described below + +### Training and test data separation + +Positive training examples are all but the last item each user rated. +Negative training examples are randomly selected from the unrated items for each user. + +The last item each user rated is used as a positive example in the test set. +A fixed set of 999 unrated items is also selected to calculate hit rate at 10 for predicting the test item. + +### Training data order + +Data is traversed randomly with 4 negative examples selected on average for every positive example.
+ +## Model + +### Publication/Attribution + +Xiangnan He, Lizi Liao, Hanwang Zhang, Liqiang Nie, Xia Hu and Tat-Seng Chua (2017). [Neural Collaborative Filtering](http://dl.acm.org/citation.cfm?id=3052569). In Proceedings of WWW '17, Perth, Australia, April 03-07, 2017. + +The author's original code is available at [hexiangnan/neural_collaborative_filtering](https://github.com/hexiangnan/neural_collaborative_filtering). diff --git a/examples/ncf/convert.py b/examples/ncf/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8e5d7ccd33ffaecaa0a63e3fd093386d70c7fb --- /dev/null +++ b/examples/ncf/convert.py @@ -0,0 +1,101 @@ +import os +from argparse import ArgumentParser +from collections import defaultdict + +import numpy as np +import pandas as pd +from tqdm import tqdm + +from load import implicit_load + + +MIN_RATINGS = 20 + + +USER_COLUMN = 'user_id' +ITEM_COLUMN = 'item_id' + + +TRAIN_RATINGS_FILENAME = 'train-ratings.csv' +TEST_RATINGS_FILENAME = 'test-ratings.csv' +TEST_NEG_FILENAME = 'test-negative.csv' + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument('path', type=str, + help='Path to reviews CSV file from MovieLens') + parser.add_argument('output', type=str, + help='Output directory for train and test CSV files') + parser.add_argument('-n', '--negatives', type=int, default=999, + help='Number of negative samples for each positive' + 'test example') + parser.add_argument('-s', '--seed', type=int, default=0, + help='Random seed to reproduce same negative samples') + return parser.parse_args() + + +def main(): + args = parse_args() + np.random.seed(args.seed) + + print("Loading raw data from {}".format(args.path)) + df = implicit_load(args.path, sort=False) + print("Filtering out users with less than {} ratings".format(MIN_RATINGS)) + grouped = df.groupby(USER_COLUMN) + df = grouped.filter(lambda x: len(x) >= MIN_RATINGS) + + print("Mapping original user and item IDs to new sequential IDs") + original_users 
= df[USER_COLUMN].unique() + original_items = df[ITEM_COLUMN].unique() + + user_map = {user: index for index, user in enumerate(original_users)} + item_map = {item: index for index, item in enumerate(original_items)} + + df[USER_COLUMN] = df[USER_COLUMN].apply(lambda user: user_map[user]) + df[ITEM_COLUMN] = df[ITEM_COLUMN].apply(lambda item: item_map[item]) + + assert df[USER_COLUMN].max() == len(original_users) - 1 + assert df[ITEM_COLUMN].max() == len(original_items) - 1 + + print("Creating list of items for each user") + # Need to sort before popping to get last item + df.sort_values(by='timestamp', inplace=True) + all_ratings = set(zip(df[USER_COLUMN], df[ITEM_COLUMN])) + user_to_items = defaultdict(list) + for row in tqdm(df.itertuples(), desc='Ratings', total=len(df)): + user_to_items[getattr(row, USER_COLUMN)].append(getattr(row, ITEM_COLUMN)) # noqa: E501 + + test_ratings = [] + test_negs = [] + all_items = set(range(len(original_items))) + print("Generating {} negative samples for each user" + .format(args.negatives)) + for user in tqdm(range(len(original_users)), desc='Users', total=len(original_users)): # noqa: E501 + test_item = user_to_items[user].pop() + + all_ratings.remove((user, test_item)) + all_negs = all_items - set(user_to_items[user]) + all_negs = sorted(list(all_negs)) # determinism + + test_ratings.append((user, test_item)) + test_negs.append(list(np.random.choice(all_negs, args.negatives))) + + print("Saving train and test CSV files to {}".format(args.output)) + df_train_ratings = pd.DataFrame(list(all_ratings)) + df_train_ratings['fake_rating'] = 1 + df_train_ratings.to_csv(os.path.join(args.output, TRAIN_RATINGS_FILENAME), + index=False, header=False, sep='\t') + + df_test_ratings = pd.DataFrame(test_ratings) + df_test_ratings['fake_rating'] = 1 + df_test_ratings.to_csv(os.path.join(args.output, TEST_RATINGS_FILENAME), + index=False, header=False, sep='\t') + + df_test_negs = pd.DataFrame(test_negs) + 
df_test_negs.to_csv(os.path.join(args.output, TEST_NEG_FILENAME), + index=False, header=False, sep='\t') + + +if __name__ == '__main__': + main() diff --git a/examples/ncf/dataset.py b/examples/ncf/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f69fd6ceca410c94fb382aa8889e8e118d2ee985 --- /dev/null +++ b/examples/ncf/dataset.py @@ -0,0 +1,132 @@ +import numpy as np +import scipy +import scipy.sparse +import torch +import torch.utils.data +import subprocess +import time +from tqdm import tqdm +import os +import pickle +import logging + +msglogger = logging.getLogger() + + +def wccount(filename): + out = subprocess.Popen(['wc', '-l', filename], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT + ).communicate()[0] + return int(out.partition(b' ')[0]) + + +class TimingContext(object): + def __init__(self, desc): + self.desc = desc + + def __enter__(self): + msglogger.info(self.desc + ' ... ') + self.start = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + end = time.time() + msglogger.info('Done in {0:.4f} seconds'.format(end - self.start)) + return True + + +class CFTrainDataset(torch.utils.data.dataset.Dataset): + def __init__(self, train_fname, nb_neg): + self._load_train_matrix(train_fname) + self.nb_neg = nb_neg + + def _load_train_matrix(self, train_fname): + pkl_name = os.path.splitext(train_fname)[0] + '_data.pkl' + npz_name = os.path.splitext(train_fname)[0] + '_mat.npz' + + if os.path.isfile(pkl_name) and os.path.isfile(npz_name): + msglogger.info('Found saved dataset data structures') + with TimingContext('Loading data list pickle'), open(pkl_name, 'rb') as f: + self.data = pickle.load(f) + with TimingContext('Loading matrix npz'): + self.mat = scipy.sparse.dok_matrix(scipy.sparse.load_npz(npz_name)) + self.nb_users = self.mat.shape[0] + self.nb_items = self.mat.shape[1] + else: + def process_line(line): + tmp = line.split('\t') + return [int(tmp[0]), int(tmp[1]), float(tmp[2]) > 0] + + with 
TimingContext('Loading CSV file'), open(train_fname, 'r') as file: + data = list(map(process_line, tqdm(file, total=wccount(train_fname)))) + + with TimingContext('Calculating min/max'): + self.nb_users = max(data, key=lambda x: x[0])[0] + 1 + self.nb_items = max(data, key=lambda x: x[1])[1] + 1 + + with TimingContext('Constructing data list'): + self.data = list(filter(lambda x: x[2], data)) + + with TimingContext('Saving data list pickle'), open(pkl_name, 'wb') as f: + pickle.dump(self.data, f) + + with TimingContext('Building dok matrix'): + self.mat = scipy.sparse.dok_matrix( + (self.nb_users, self.nb_items), dtype=np.float32) + for user, item, _ in tqdm(data): + self.mat[user, item] = 1. + + with TimingContext('Converting to COO matrix and saving'): + scipy.sparse.save_npz(npz_name, self.mat.tocoo(copy=True)) + + def __len__(self): + return (self.nb_neg + 1) * len(self.data) + + def __getitem__(self, idx): + if idx % (self.nb_neg + 1) == 0: + idx = idx // (self.nb_neg + 1) + return self.data[idx][0], self.data[idx][1], np.ones(1, dtype=np.float32) # noqa: E501 + else: + idx = idx // (self.nb_neg + 1) + u = self.data[idx][0] + j = torch.LongTensor(1).random_(0, self.nb_items).item() + while (u, j) in self.mat: + j = torch.LongTensor(1).random_(0, self.nb_items).item() + return u, j, np.zeros(1, dtype=np.float32) + + +def load_test_ratings(fname): + pkl_name = os.path.splitext(fname)[0] + '.pkl' + if os.path.isfile(pkl_name): + with TimingContext('Found test rating pickle file - loading'), open(pkl_name, 'rb') as f: + res = pickle.load(f) + else: + def process_line(line): + tmp = map(int, line.split('\t')[0:2]) + return list(tmp) + with TimingContext('Loading test ratings from csv'), open(fname, 'r') as f: + ratings = map(process_line, tqdm(f, total=wccount(fname))) + res = list(ratings) + with TimingContext('Saving test ratings list pickle'), open(pkl_name, 'wb') as f: + pickle.dump(res, f) + + return res + + +def load_test_negs(fname): + pkl_name = 
RatingData = namedtuple('RatingData',
                        ['items', 'users', 'ratings', 'min_date', 'max_date'])

# Column names shared by the header-less MovieLens rating files.
_RATING_COLUMNS = ('user_id', 'item_id', 'rating', 'timestamp')


def describe_ratings(ratings):
    """Print and return summary statistics of a ratings DataFrame.

    Args:
        ratings: DataFrame with ``item_id``, ``user_id`` and ``timestamp`` columns.

    Returns:
        RatingData: distinct item/user counts, row count, min/max timestamp.
    """
    info = RatingData(items=len(ratings['item_id'].unique()),
                      users=len(ratings['user_id'].unique()),
                      ratings=len(ratings),
                      min_date=ratings['timestamp'].min(),
                      max_date=ratings['timestamp'].max())
    print("{ratings} ratings on {items} items from {users} users"
          " from {min_date} to {max_date}"
          .format(**(info._asdict())))
    return info


def process_movielens(ratings, sort=True):
    """Convert epoch-second timestamps to datetimes and optionally sort by them.

    The DataFrame is modified in place and also returned.
    """
    ratings['timestamp'] = pd.to_datetime(ratings['timestamp'], unit='s')
    if sort:
        ratings.sort_values(by='timestamp', inplace=True)
    describe_ratings(ratings)
    return ratings


def load_ml_100k(filename, sort=True):
    """Load a tab-separated, header-less ratings file."""
    ratings = pd.read_csv(filename, sep='\t', names=list(_RATING_COLUMNS))
    return process_movielens(ratings, sort=sort)


def load_ml_1m(filename, sort=True):
    """Load a ``::``-separated, header-less ratings file."""
    ratings = pd.read_csv(filename, sep='::', names=list(_RATING_COLUMNS), engine='python')
    return process_movielens(ratings, sort=sort)


def load_ml_10m(filename, sort=True):
    """Load a ``::``-separated, header-less ratings file."""
    ratings = pd.read_csv(filename, sep='::', names=list(_RATING_COLUMNS), engine='python')
    return process_movielens(ratings, sort=sort)


def load_ml_20m(filename, sort=True):
    """Load a CSV with a header row using ``userId`` / ``movieId`` column names."""
    ratings = pd.read_csv(filename)
    ratings.rename(columns={'userId': 'user_id', 'movieId': 'item_id'}, inplace=True)
    # The timestamp conversion happens once, inside process_movielens.
    return process_movielens(ratings, sort=sort)


# Explicit registry instead of scanning locals()/globals() for "load_*" names.
# A new loader must be added to both tuples below.
DATASETS = ['ml_100k', 'ml_1m', 'ml_10m', 'ml_20m']
_LOADERS = dict(zip(DATASETS, (load_ml_100k, load_ml_1m, load_ml_10m, load_ml_20m)))


def get_dataset_name(filename):
    """Return the first entry of DATASETS found in *filename*.

    Matching is case-insensitive and treats ``-`` as ``_``.

    Raises:
        NotImplementedError: if no known dataset name occurs in *filename*.
    """
    normalized = filename.replace('-', '_').lower()
    for dataset in DATASETS:
        if dataset in normalized:
            return dataset
    raise NotImplementedError(
        'Cannot infer dataset from {!r}; expected one of {}'.format(filename, DATASETS))


def implicit_load(filename, sort=True):
    """Load *filename* with the loader selected by ``get_dataset_name``."""
    return _LOADERS[get_dataset_name(filename)](filename, sort=sort)
+formatter: simple + +[handler_file] +class: FileHandler +mode: 'w' +args=('%(logfilename)s', 'w') +formatter: time_simple + +[logger_root] +level: INFO +propagate: 1 +handlers: console, file + +[logger_app_cfg] +# Use this logger to log the application configuration and execution environment +level: DEBUG +qualname: app_cfg +propagate: 0 +handlers: file diff --git a/examples/ncf/ncf.py b/examples/ncf/ncf.py new file mode 100644 index 0000000000000000000000000000000000000000..98376ca825ae800cbfbc8aeae08df176b765d174 --- /dev/null +++ b/examples/ncf/ncf.py @@ -0,0 +1,471 @@ +import os +import heapq +import math +import time +from functools import partial +from datetime import datetime +from collections import OrderedDict +from argparse import ArgumentParser +import sys + +import tqdm +import numpy as np +import torch +import torch.nn as nn +from torch import multiprocessing as mp + +import utils +from neumf import NeuMF +from dataset import CFTrainDataset, load_test_ratings, load_test_negs +from convert import (TEST_NEG_FILENAME, TEST_RATINGS_FILENAME, + TRAIN_RATINGS_FILENAME) + +import distiller +import distiller.quantization as quantization +import distiller.apputils as apputils +from distiller.data_loggers import TensorBoardLogger, PythonLogger + +msglogger = None + + +def parse_args(): + parser = ArgumentParser(description="Train a Nerual Collaborative" + " Filtering model") + parser.add_argument('data', type=str, + help='path to test and training data files') + parser.add_argument('-e', '--epochs', type=int, default=20, + help='number of epochs for training') + parser.add_argument('-b', '--batch-size', type=int, default=256, + help='number of examples for each iteration') + parser.add_argument('-f', '--factors', type=int, default=8, + help='number of predictive factors') + parser.add_argument('--layers', nargs='+', type=int, + default=[64, 32, 16, 8], + help='size of hidden layers for MLP') + parser.add_argument('-n', '--negative-samples', type=int, default=4, 
+ help='number of negative examples per interaction') + parser.add_argument('-l', '--learning-rate', type=float, default=0.001, + help='learning rate for optimizer') + parser.add_argument('-k', '--topk', type=int, default=10, + help='rank for test examples to be considered a hit') + parser.add_argument('--no-cuda', action='store_true', + help='use available GPUs') + parser.add_argument('--seed', '-s', type=int, + help='manually set random seed for torch') + parser.add_argument('--threshold', '-t', type=float, + help='stop training early at threshold') + parser.add_argument('--processes', '-p', type=int, default=1, + help='Number of processes for evaluating model') + parser.add_argument('--workers', '-w', type=int, default=8, + help='Number of workers for training DataLoader') + + # Distiller Args + # summary_choices = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx'] + # parser.add_argument('--summary', type=str, choices=summary_choices, + # help='print a summary of the model, and exit - options: ' + + # ' | '.join(summary_choices)) + parser.add_argument('--load', type=str, metavar='PATH') + parser.add_argument('--reset-optimizer', action='store_true') + parser.add_argument('--eval', '--evaluate', action='store_true') + parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store', + help='configuration file for pruning the model (default is to use hard-coded schedule)') + parser.add_argument('--gpus', metavar='DEV_ID', default=None, + help='Comma-separated list of GPU device IDs to be used ' + '(default is to use all available devices)') + parser.add_argument('--out-dir', '-o', dest='output_dir', default=os.path.join('run', 'neumf'), + help='Path to dump logs and checkpoints') + parser.add_argument('--name', metavar='NAME', default=None, help='Experiment name') + parser.add_argument('--log-freq', '--lf', default=100, type=int, metavar='N', help='Logging frequency') + parser.add_argument('--param-hist', 
dest='log_params_histograms', action='store_true', default=False, + help='log the parameter tensors histograms to file ' + '(WARNING: this can use significant disk space)') + parser.add_argument('--split-final', '--sf', action='store_true') + parser.add_argument('--eval-fp16', action='store_true') + parser.add_argument('--activation-histograms', '--act-hist', + type=distiller.utils.float_range_argparse_checker(exc_min=True), + metavar='PORTION_OF_TEST_SET', + help='Run the model in evaluation mode on the specified portion of the test dataset and ' + 'generate activation histograms. NOTE: This slows down evaluation significantly') + quantization.add_post_train_quant_args(parser) + + return parser.parse_args() + + +def predict(model, users, items, batch_size=1024, use_cuda=True): + with torch.no_grad(): + batches = [(users[i:i + batch_size], items[i:i + batch_size]) + for i in range(0, len(users), batch_size)] + preds = [] + for user, item in batches: + def proc(x): + x = np.array(x) + x = torch.from_numpy(x) + if use_cuda: + x = x.cuda(async=True) + return torch.autograd.Variable(x) + outp = model(proc(user), proc(item), torch.tensor([True], dtype=torch.bool)) + outp = outp.data.cpu().numpy() + preds += list(outp.flatten()) + return preds + + +def _calculate_hit(ranked, test_item): + return int(test_item in ranked) + + +def _calculate_ndcg(ranked, test_item): + for i, item in enumerate(ranked): + if item == test_item: + return math.log(2) / math.log(i + 2) + return 0. 
+ + +def eval_one(rating, items, model, K, use_cuda=True): + user = rating[0] + test_item = rating[1] + items.append(test_item) + # items.insert(0, test_item) + users = [user] * len(items) + predictions = predict(model, users, items, use_cuda=use_cuda) + + map_item_score = {item: pred for item, pred in zip(items, predictions)} + ranked = heapq.nlargest(K, map_item_score, key=map_item_score.get) + + hit = _calculate_hit(ranked, test_item) + ndcg = _calculate_ndcg(ranked, test_item) + # return user, hit, ndcg + return hit, ndcg + + +def val_epoch(model, ratings, negs, K, use_cuda=True, output=None, epoch=None, + processes=1, num_users=-1): + if epoch is None: + msglogger.info("Initial evaluation") + else: + msglogger.info("Epoch {} evaluation".format(epoch)) + start = datetime.now() + model.eval() + + if num_users > 0: + ratings = ratings[:num_users] + negs = negs[:num_users] + + if processes > 1: + context = mp.get_context('spawn') + _eval_one = partial(eval_one, model=model, K=K, use_cuda=use_cuda) + with context.Pool(processes=processes) as workers: + hits_and_ndcg = workers.starmap(_eval_one, zip(ratings, negs)) + hits, ndcgs = zip(*hits_and_ndcg) + else: + hits, ndcgs = [], [] + with tqdm.tqdm(zip(ratings, negs), total=len(ratings)) as t: + for rating, items in t: + hit, ndcg = eval_one(rating, items, model, K, use_cuda=use_cuda) + hits.append(hit) + ndcgs.append(ndcg) + steps_completed = len(hits) + 1 + if steps_completed % 100 == 0: + t.set_description('HR@10 = {0:.4f}, NDCG = {1:.4f}'.format(np.mean(hits), np.mean(ndcgs))) + + hits = np.array(hits, dtype=np.float32) + ndcgs = np.array(ndcgs, dtype=np.float32) + + end = datetime.now() + if output is not None: + result = OrderedDict() + result['timestamp'] = datetime.now() + result['duration'] = end - start + result['epoch'] = epoch + result['K'] = K + result['hit_rate'] = np.mean(hits) + result['NDCG'] = np.mean(ndcgs) + utils.save_result(result, output) + + return hits, ndcgs + + +def main(): + global 
def main():
    """Entry point: train, evaluate, calibrate or quantize an NCF model.

    Depending on the parsed arguments this either
      * collects quantization calibration stats / activation histograms,
      * evaluates (optionally post-train quantized or fp16) and returns, or
      * runs the training loop with per-epoch validation and checkpointing.

    Returns:
        0 for the calibration and evaluation modes, None after training.
    """
    global msglogger

    script_dir = os.path.dirname(__file__)
    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))

    args = parse_args()

    # Distiller loggers
    msglogger = apputils.config_pylogger('logging.conf', args.name, output_dir=args.output_dir)
    tflogger = TensorBoardLogger(msglogger.logdir)

    if args.seed is not None:
        msglogger.info("Using seed = {}".format(args.seed))
        torch.manual_seed(args.seed)
        np.random.seed(seed=args.seed)

    # Keep only the part after the '.' of the string form of these two options.
    args.qe_mode = str(args.qe_mode).split('.')[1]
    args.qe_clip_acts = str(args.qe_clip_acts).split('.')[1]

    apputils.log_execution_env_state(sys.argv, gitroot=module_path)

    if args.gpus is not None:
        try:
            args.gpus = [int(s) for s in args.gpus.split(',')]
        except ValueError:
            msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
            sys.exit(1)
        if len(args.gpus) > 1:
            msglogger.error('ERROR: Only single GPU supported for NCF')
            sys.exit(1)
        available_gpus = torch.cuda.device_count()
        for dev_id in args.gpus:
            if dev_id >= available_gpus:
                msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
                                .format(dev_id, available_gpus))
                sys.exit(1)
        # Set default device in case the first one on the list != 0
        torch.cuda.set_device(args.gpus[0])

    # Save configuration to file
    config = dict(args.__dict__)
    config['timestamp'] = "{:.0f}".format(datetime.utcnow().timestamp())
    config['local_timestamp'] = str(datetime.now())
    run_dir = msglogger.logdir
    msglogger.info("Saving config and results to {}".format(run_dir))
    if not os.path.exists(run_dir) and run_dir != '':
        os.makedirs(run_dir)
    utils.save_config(config, run_dir)

    # Check that GPUs are actually available
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    t1 = time.time()
    # Load Data
    training = not (args.eval or args.qe_calibration or args.activation_histograms)
    msglogger.info('Loading data')
    if training:
        train_dataset = CFTrainDataset(
            os.path.join(args.data, TRAIN_RATINGS_FILENAME), args.negative_samples)
        train_dataloader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)
        nb_users, nb_items = train_dataset.nb_users, train_dataset.nb_items
    else:
        train_dataset = None
        train_dataloader = None
        # NOTE(review): hard-coded sizes, presumably those of the ml-20m
        # dataset; other datasets would need different values. Confirm.
        nb_users, nb_items = (138493, 26744)

    test_ratings = load_test_ratings(os.path.join(args.data, TEST_RATINGS_FILENAME))  # noqa: E501
    test_negs = load_test_negs(os.path.join(args.data, TEST_NEG_FILENAME))

    msglogger.info('Load data done [%.1f s]. #user=%d, #item=%d, #train=%s, #test=%d'
                   % (time.time()-t1, nb_users, nb_items, str(train_dataset.mat.nnz) if training else 'N/A',
                      len(test_ratings)))

    # Create model
    model = NeuMF(nb_users, nb_items,
                  mf_dim=args.factors, mf_reg=0.,
                  mlp_layer_sizes=args.layers,
                  mlp_layer_regs=[0. for _ in args.layers],
                  split_final=args.split_final)
    if use_cuda:
        model = model.cuda()
    msglogger.info(model)
    msglogger.info("{} parameters".format(utils.count_parameters(model)))

    # Save model text description
    with open(os.path.join(run_dir, 'model.txt'), 'w') as file:
        file.write(str(model))

    compression_scheduler = None
    start_epoch = 0
    optimizer = None
    if args.load:
        if training:
            model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(model, args.load)
            if args.reset_optimizer:
                start_epoch = 0
                optimizer = None
        else:
            model = apputils.load_lean_checkpoint(model, args.load)

    # Add loss to graph
    criterion = nn.BCEWithLogitsLoss()

    if use_cuda:
        criterion = criterion.cuda()

    if training and optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
        msglogger.info('Optimizer Type: %s', type(optimizer))
        msglogger.info('Optimizer Args: %s', optimizer.defaults)

    if args.compress:
        compression_scheduler = distiller.file_config(model, optimizer, args.compress)
        # Move the model to the GPU again after applying the compression
        # config, but only when CUDA is in use (this was unconditional before
        # and failed with --no-cuda or on CPU-only hosts).
        if use_cuda:
            model.cuda()

    # Create files for tracking training
    valid_results_file = os.path.join(run_dir, 'valid_results.csv')

    if args.qe_calibration or args.activation_histograms:
        calib = {'portion': args.qe_calibration,
                 'desc_str': 'quantization calibration stats',
                 'collect_func': partial(distiller.data_loggers.collect_quant_stats, inplace_runtime_check=True,
                                         disable_inplace_attrs=True)}
        hists = {'portion': args.activation_histograms,
                 'desc_str': 'activation histograms',
                 'collect_func': partial(distiller.data_loggers.collect_histograms, activation_stats=None, nbins=2048,
                                         save_hist_imgs=True)}
        d = calib if args.qe_calibration else hists

        distiller.utils.assign_layer_fq_names(model)
        num_users = int(np.floor(len(test_ratings) * d['portion']))
        msglogger.info(
            "Generating {} based on {:.1%} of the test-set ({} users)".format(d['desc_str'], d['portion'], num_users))

        test_fn = partial(val_epoch, ratings=test_ratings, negs=test_negs, K=args.topk, use_cuda=use_cuda,
                          processes=args.processes, num_users=num_users)
        d['collect_func'](model=model, test_fn=test_fn, save_dir=run_dir, classes=None)

        return 0

    if args.eval:
        if args.quantize_eval and args.qe_calibration is None:
            # The quantizer is prepared with CPU dummy inputs, so the model is
            # moved to the CPU first and back to the GPU afterwards.
            model.cpu()
            quantizer = quantization.PostTrainLinearQuantizer.from_args(model, args)
            dummy_input = (torch.tensor([1]), torch.tensor([1]), torch.tensor([True], dtype=torch.bool))
            quantizer.prepare_model(dummy_input)
            if use_cuda:
                model.cuda()

        distiller.utils.assign_layer_fq_names(model)

        if args.eval_fp16:
            model = model.half()

        # Calculate initial Hit Ratio and NDCG
        begin = time.time()
        hits, ndcgs = val_epoch(model, test_ratings, test_negs, args.topk,
                                use_cuda=use_cuda, processes=args.processes)
        val_time = time.time() - begin
        hit_rate = np.mean(hits)
        msglogger.info('Initial HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}, val_time = {val_time:.2f}'
                       .format(K=args.topk, hit_rate=hit_rate, ndcg=np.mean(ndcgs), val_time=val_time))

        if args.quantize_eval:
            # Store the measured hit rate. It used to be reset to 0 right
            # before being written into the checkpoint extras.
            checkpoint_name = 'quantized'
            apputils.save_checkpoint(0, 'NCF', model, optimizer=None, extras={'quantized_hr@10': hit_rate},
                                     name='_'.join([args.name, 'quantized']) if args.name else checkpoint_name,
                                     dir=msglogger.logdir)
        return 0

    total_samples = len(train_dataloader.sampler)
    steps_per_epoch = math.ceil(total_samples / args.batch_size)
    best_hit_rate = 0
    best_epoch = 0
    # Sigmoid flag passed to the model during training: BCEWithLogitsLoss
    # expects raw logits, so the model must not apply its own sigmoid.
    no_sigmoid = torch.tensor([False], dtype=torch.bool)
    for epoch in range(start_epoch, args.epochs):
        msglogger.info('')
        model.train()
        losses = utils.AverageMeter()

        begin = time.time()

        if compression_scheduler:
            compression_scheduler.on_epoch_begin(epoch, optimizer)

        loader = tqdm.tqdm(train_dataloader)
        for batch_index, (user, item, label) in enumerate(loader):
            if use_cuda:
                # `async=True` is a syntax error since Python 3.7; the
                # keyword accepted by Tensor.cuda is `non_blocking`.
                user = user.cuda(non_blocking=True)
                item = item.cuda(non_blocking=True)
                label = label.cuda(non_blocking=True)

            if compression_scheduler:
                compression_scheduler.on_minibatch_begin(epoch, batch_index, steps_per_epoch, optimizer)

            outputs = model(user, item, no_sigmoid)
            loss = criterion(outputs, label)

            if compression_scheduler:
                # NOTE(review): the return value is ignored, as in the original
                # code; confirm the scheduler is not expected to return an
                # updated loss that should be back-propagated instead.
                compression_scheduler.before_backward_pass(epoch, batch_index, steps_per_epoch, loss, optimizer,
                                                           return_loss_components=False)

            losses.update(loss.item(), user.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if compression_scheduler:
                compression_scheduler.on_minibatch_end(epoch, batch_index, steps_per_epoch, optimizer)

            # Show running loss on the progress bar
            description = ('Epoch {} Loss {loss.val:.4f} ({loss.avg:.4f})'
                           .format(epoch, loss=losses))
            loader.set_description(description)

            steps_completed = batch_index + 1
            if steps_completed % args.log_freq == 0:
                stats_dict = OrderedDict()
                stats_dict['Loss'] = losses.avg
                stats = ('Performance/Training/', stats_dict)
                params = model.named_parameters() if args.log_params_histograms else None
                distiller.log_training_progress(stats, params, epoch, steps_completed, steps_per_epoch, args.log_freq,
                                                [tflogger])

                # NOTE(review): nesting under the log_freq check is assumed;
                # confirm against the upstream file.
                tflogger.log_model_buffers(model, ['tracked_min', 'tracked_max'], 'Quant/Train/Acts/TrackedMinMax',
                                           epoch, steps_completed, steps_per_epoch, args.log_freq)

        train_time = time.time() - begin
        begin = time.time()
        hits, ndcgs = val_epoch(model, test_ratings, test_negs, args.topk,
                                use_cuda=use_cuda, output=valid_results_file,
                                epoch=epoch, processes=args.processes)
        val_time = time.time() - begin

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch, optimizer)

        hit_rate = np.mean(hits)
        mean_ndcgs = np.mean(ndcgs)

        stats_dict = OrderedDict()
        stats_dict['HR@{0}'.format(args.topk)] = hit_rate
        stats_dict['NDCG@{0}'.format(args.topk)] = mean_ndcgs
        stats = ('Performance/Validation/', stats_dict)
        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1,
                                        loggers=[tflogger])

        msglogger.info('Epoch {epoch}: HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}, AvgTrainLoss = {loss.avg:.4f}, '
                       'train_time = {train_time:.2f}, val_time = {val_time:.2f}'.format(
                           epoch=epoch, K=args.topk, hit_rate=hit_rate, ndcg=mean_ndcgs,
                           loss=losses, train_time=train_time, val_time=val_time))

        is_best = False
        if hit_rate > best_hit_rate:
            best_hit_rate = hit_rate
            is_best = True
            best_epoch = epoch
        extras = {'current_hr@10': hit_rate,
                  'best_hr@10': best_hit_rate,
                  'best_epoch': best_epoch}
        apputils.save_checkpoint(epoch, 'NCF', model, optimizer, compression_scheduler, extras, is_best, dir=run_dir)

        if args.threshold is not None:
            if hit_rate >= args.threshold:
                msglogger.info("Hit threshold of {}".format(args.threshold))
                break
class NeuMF(nn.Module):
    """Neural matrix factorization model with an MF branch and an MLP branch.

    The MF branch multiplies user and item embeddings element-wise; the MLP
    branch concatenates a second pair of embeddings and passes them through
    Linear+ReLU layers. The two branches are merged either by one Linear over
    their concatenation (``split_final=False``) or by one Linear per branch
    whose outputs are added (``split_final=True``).

    Args:
        nb_users: number of rows of the user embedding tables.
        nb_items: number of rows of the item embedding tables.
        mf_dim: embedding width of the MF branch.
        mf_reg: accepted but not used by this class.
        mlp_layer_sizes: MLP layer widths; the first entry must be even
            because it is split between the user and item embeddings.
        mlp_layer_regs: must have the same length as *mlp_layer_sizes*;
            otherwise not used by this class.
        split_final: selects the merge variant described above.

    Raises:
        RuntimeError: on mismatched list lengths or an odd first MLP size.
    """

    def __init__(self, nb_users, nb_items,
                 mf_dim, mf_reg,
                 mlp_layer_sizes, mlp_layer_regs, split_final=False):
        if len(mlp_layer_sizes) != len(mlp_layer_regs):
            raise RuntimeError('u dummy, layer_sizes != layer_regs!')
        if mlp_layer_sizes[0] % 2 != 0:
            raise RuntimeError('u dummy, mlp_layer_sizes[0] % 2 != 0')
        super(NeuMF, self).__init__()

        self.mf_dim = mf_dim
        self.mlp_layer_sizes = mlp_layer_sizes

        # TODO: regularization?
        self.mf_user_embed = nn.Embedding(nb_users, mf_dim)
        self.mf_item_embed = nn.Embedding(nb_items, mf_dim)
        self.mlp_user_embed = nn.Embedding(nb_users, mlp_layer_sizes[0] // 2)
        self.mlp_item_embed = nn.Embedding(nb_items, mlp_layer_sizes[0] // 2)

        self.mf_mult = distiller.modules.EltwiseMult()
        self.mlp_concat = distiller.modules.Concat(dim=1)

        # One Linear+ReLU pair per consecutive pair of layer sizes.
        self.mlp = nn.ModuleList()
        self.mlp_relu = nn.ModuleList()
        for in_features, out_features in zip(mlp_layer_sizes, mlp_layer_sizes[1:]):
            self.mlp.append(nn.Linear(in_features, out_features))
            self.mlp_relu.append(nn.ReLU())

        self.split_final = split_final
        if not split_final:
            self.final_concat = distiller.modules.Concat(dim=1)
            self.final = nn.Linear(mlp_layer_sizes[-1] + mf_dim, 1)
        else:
            self.final_mlp = nn.Linear(mlp_layer_sizes[-1], 1)
            self.final_mf = nn.Linear(mf_dim, 1)
            self.final_add = distiller.modules.EltwiseAdd()

        self.sigmoid = nn.Sigmoid()

        self.mf_user_embed.weight.data.normal_(0., 0.01)
        self.mf_item_embed.weight.data.normal_(0., 0.01)
        self.mlp_user_embed.weight.data.normal_(0., 0.01)
        self.mlp_item_embed.weight.data.normal_(0., 0.01)

        def glorot_uniform(layer):
            # U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))
            limit = np.sqrt(6. / (layer.in_features + layer.out_features))
            layer.weight.data.uniform_(-limit, limit)

        def lecun_uniform(layer):
            # U(-limit, limit) with limit = sqrt(3 / fan_in)
            limit = np.sqrt(3. / layer.in_features)
            layer.weight.data.uniform_(-limit, limit)

        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                glorot_uniform(layer)
        if not split_final:
            lecun_uniform(self.final)
        else:
            lecun_uniform(self.final_mlp)
            lecun_uniform(self.final_mf)

    @staticmethod
    def _adapt_final_layer_keys(state_dict, split_final, mf_dim):
        """Convert the final-layer entries of *state_dict* to the requested layout.

        Returns *state_dict* itself when no conversion is needed, otherwise a
        shallow copy with converted entries; the input is never modified.
        """
        fused_to_split = 'final.weight' in state_dict and split_final
        split_to_fused = 'final_mf.weight' in state_dict and not split_final
        if not (fused_to_split or split_to_fused):
            return state_dict

        # Keep the optional `_metadata` attribute that torch attaches to
        # state dicts; a plain copy() does not carry instance attributes.
        metadata = getattr(state_dict, '_metadata', None)
        converted = state_dict.copy()
        if metadata is not None:
            converted._metadata = metadata

        if fused_to_split:
            # Loading no-split checkpoint into split model.
            # MF weights come first, then MLP.
            final_weight = converted.pop('final.weight')
            converted['final_mf.weight'] = final_weight[0][:mf_dim].unsqueeze(0)
            converted['final_mlp.weight'] = final_weight[0][mf_dim:].unsqueeze(0)

            # Split bias 50-50 so the sum of both branches is unchanged.
            final_bias = converted.pop('final.bias')
            converted['final_mf.bias'] = final_bias * 0.5
            converted['final_mlp.bias'] = final_bias * 0.5
        else:
            # Loading split checkpoint into no-split model
            converted['final.weight'] = torch.cat((converted.pop('final_mf.weight')[0],
                                                   converted.pop('final_mlp.weight')[0])).unsqueeze(0)
            converted['final.bias'] = converted.pop('final_mf.bias') + converted.pop('final_mlp.bias')
        return converted

    def load_state_dict(self, state_dict, strict=True):
        """Load *state_dict*, converting between split and non-split final layers.

        The caller's dict is left untouched and the result of the parent
        ``load_state_dict`` is returned.
        """
        state_dict = self._adapt_final_layer_keys(state_dict, self.split_final, self.mf_dim)
        return super(NeuMF, self).load_state_dict(state_dict, strict)

    def forward(self, user, item, sigmoid):
        """Score user/item pairs; apply the sigmoid only when *sigmoid* is truthy."""
        xmfu = self.mf_user_embed(user)
        xmfi = self.mf_item_embed(item)
        xmf = self.mf_mult(xmfu, xmfi)

        xmlpu = self.mlp_user_embed(user)
        xmlpi = self.mlp_item_embed(item)
        xmlp = self.mlp_concat(xmlpu, xmlpi)
        for layer, act in zip(self.mlp, self.mlp_relu):
            xmlp = act(layer(xmlp))

        if not self.split_final:
            x = self.final_concat(xmf, xmlp)
            x = self.final(x)
        else:
            xmf = self.final_mf(xmf)
            xmlp = self.final_mlp(xmlp)
            x = self.final_add(xmf, xmlp)
        if sigmoid:
            x = self.sigmoid(x)
        return x
0000000000000000000000000000000000000000..f32232c563dce72256d95b8a9e36363d5fe1abf4 --- /dev/null +++ b/examples/ncf/quantization_stats_no_split.yaml @@ -0,0 +1,294 @@ +mf_user_embed: + inputs: + 0: + min: 0 + max: 13848 + avg_min: 6924.0 + avg_max: 6924.0 + mean: 6924.0 + std: 3997.8620729189115 + shape: (1000) + output: + min: -2.3631081581115723 + max: 2.634204864501953 + avg_min: -1.071664170499179 + avg_max: 1.10968593620068 + mean: 0.004670127670715561 + std: 0.4853287812380178 + shape: (1000, 64) +mf_item_embed: + inputs: + 0: + min: 0 + max: 26743 + avg_min: 32.519748718318944 + avg_max: 26716.69644017626 + mean: 13429.908568849716 + std: 7695.305228284578 + shape: (1000) + output: + min: -3.886291742324829 + max: 3.3996453285217285 + avg_min: -1.8537866596848787 + avg_max: 1.733846694667886 + mean: -0.1020138903885282 + std: 0.9169524390367664 + shape: (1000, 64) +mlp_user_embed: + inputs: + 0: + min: 0 + max: 13848 + avg_min: 6924.0 + avg_max: 6924.0 + mean: 6924.0 + std: 3997.8620729189115 + shape: (1000) + output: + min: -2.3077099323272705 + max: 2.019761323928833 + avg_min: -0.8393908124364596 + avg_max: 0.8563097013907461 + mean: -0.0058890159863465566 + std: 0.33145011084640036 + shape: (1000, 128) +mlp_item_embed: + inputs: + 0: + min: 0 + max: 26743 + avg_min: 32.519748718318944 + avg_max: 26716.69644017626 + mean: 13429.908568849716 + std: 7695.305228284578 + shape: (1000) + output: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.6649888157199006 + avg_max: 1.5336890253437756 + mean: 0.03420270613169851 + std: 0.6040079522209375 + shape: (1000, 128) +mf_mult: + inputs: + 0: + min: -2.3631081581115723 + max: 2.634204864501953 + avg_min: -1.071664170499179 + avg_max: 1.10968593620068 + mean: 0.004670127670715561 + std: 0.4853287812380178 + shape: (1000, 64) + 1: + min: -3.886291742324829 + max: 3.3996453285217285 + avg_min: -1.8537866596848787 + avg_max: 1.733846694667886 + mean: -0.1020138903885282 + std: 0.9169524390367664 + 
shape: (1000, 64) + output: + min: -6.388758659362793 + max: 7.461198329925537 + avg_min: -1.4091147140708975 + avg_max: 1.190038402591932 + mean: -0.03373257358976464 + std: 0.47984411373143276 + shape: (1000, 64) +mlp_concat: + inputs: + 0: + min: -2.3077099323272705 + max: 2.019761323928833 + avg_min: -0.8393908124364596 + avg_max: 0.8563097013907461 + mean: -0.0058890159863465566 + std: 0.33145011084640036 + shape: (1000, 128) + 1: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.6649888157199006 + avg_max: 1.5336890253437756 + mean: 0.03420270613169851 + std: 0.6040079522209375 + shape: (1000, 128) + output: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.728540082865832 + avg_max: 1.5992073092841723 + mean: 0.014156845084580252 + std: 0.4875919206474072 + shape: (1000, 256) +mlp.0: + inputs: + 0: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.728540082865832 + avg_max: 1.5992073092841723 + mean: 0.014156845084580252 + std: 0.4875919206474072 + shape: (1000, 256) + output: + min: -25.551782608032227 + max: 30.255319595336914 + avg_min: -10.410937029339797 + avg_max: 12.199230058019891 + mean: -1.3596681231175995 + std: 3.435064250450655 + shape: (1000, 256) +mlp.1: + inputs: + 0: + min: 0.0 + max: 30.255319595336914 + avg_min: 0.0 + avg_max: 12.199230058019891 + mean: 0.6925220241296313 + std: 1.6946167814794522 + shape: (1000, 256) + output: + min: -231.78152465820312 + max: 82.65782165527344 + avg_min: -60.22239387931445 + avg_max: 21.538379845476946 + mean: -11.969065733546032 + std: 16.571529820325505 + shape: (1000, 128) +mlp.2: + inputs: + 0: + min: 0.0 + max: 82.65782165527344 + avg_min: 0.0 + avg_max: 21.538379845476946 + mean: 1.362308199108358 + std: 3.989317218613674 + shape: (1000, 128) + output: + min: -235.94625854492188 + max: 203.7071990966797 + avg_min: -54.71186078950498 + avg_max: 29.957007628019554 + mean: -7.084710118819055 + std: 21.72214522135367 + shape: (1000, 64) +mlp_relu.0: + 
inputs: + 0: + min: -25.551782608032227 + max: 30.255319595336914 + avg_min: -10.410937029339797 + avg_max: 12.199230058019891 + mean: -1.3596681231175995 + std: 3.435064250450655 + shape: (1000, 256) + output: + min: 0.0 + max: 30.255319595336914 + avg_min: 0.0 + avg_max: 12.199230058019891 + mean: 0.6925220241296313 + std: 1.6946167814794522 + shape: (1000, 256) +mlp_relu.1: + inputs: + 0: + min: -231.78152465820312 + max: 82.65782165527344 + avg_min: -60.22239387931445 + avg_max: 21.538379845476946 + mean: -11.969065733546032 + std: 16.571529820325505 + shape: (1000, 128) + output: + min: 0.0 + max: 82.65782165527344 + avg_min: 0.0 + avg_max: 21.538379845476946 + mean: 1.362308199108358 + std: 3.989317218613674 + shape: (1000, 128) +mlp_relu.2: + inputs: + 0: + min: -235.94625854492188 + max: 203.7071990966797 + avg_min: -54.71186078950498 + avg_max: 29.957007628019554 + mean: -7.084710118819055 + std: 21.72214522135367 + shape: (1000, 64) + output: + min: 0.0 + max: 203.7071990966797 + avg_min: 0.0 + avg_max: 29.95700772112138 + mean: 4.937804440873921 + std: 11.42688696572864 + shape: (1000, 64) +final_concat: + inputs: + 0: + min: -6.388758659362793 + max: 7.461198329925537 + avg_min: -1.4091147140708975 + avg_max: 1.190038402591932 + mean: -0.03373257358976464 + std: 0.47984411373143276 + shape: (1000, 64) + 1: + min: 0.0 + max: 203.7071990966797 + avg_min: 0.0 + avg_max: 29.95700772112138 + mean: 4.937804440873921 + std: 11.42688696572864 + shape: (1000, 64) + output: + min: -6.388758659362793 + max: 203.7071990966797 + avg_min: -1.4091147140708975 + avg_max: 29.95700872795735 + mean: 2.452035934175203 + std: 8.46057860792389 + shape: (1000, 128) +final: + inputs: + 0: + min: -6.388758659362793 + max: 203.7071990966797 + avg_min: -1.4091147140708975 + avg_max: 29.95700872795735 + mean: 2.452035934175203 + std: 8.46057860792389 + shape: (1000, 128) + output: + min: -264.23663330078125 + max: 10.719743728637695 + avg_min: -64.09727749207161 + avg_max: 
4.514405594118789 + mean: -27.331936557333087 + std: 29.674832823876194 + shape: (1000, 1) +sigmoid: + inputs: + 0: + min: -264.23663330078125 + max: 10.719743728637695 + avg_min: -64.09727749207161 + avg_max: 4.514405594118789 + mean: -27.331936557333087 + std: 29.674832823876194 + shape: (1000, 1) + output: + min: 0.0 + max: 0.9999779462814331 + avg_min: 7.22119551814259e-08 + avg_max: 0.9796236092019589 + mean: 0.025780337072657727 + std: 0.11967490732565136 + shape: (1000, 1) diff --git a/examples/ncf/quantization_stats_split.yaml b/examples/ncf/quantization_stats_split.yaml new file mode 100644 index 0000000000000000000000000000000000000000..afbc00409ba2d336701d5a21095b67c68400db21 --- /dev/null +++ b/examples/ncf/quantization_stats_split.yaml @@ -0,0 +1,312 @@ +mf_user_embed: + inputs: + 0: + min: 0 + max: 13848 + avg_min: 6924.0 + avg_max: 6924.0 + mean: 6924.0 + std: 3997.8620729189115 + shape: (1000) + output: + min: -2.3631081581115723 + max: 2.634204864501953 + avg_min: -1.071664170499179 + avg_max: 1.10968593620068 + mean: 0.004670127670715561 + std: 0.4853287812380178 + shape: (1000, 64) +mf_item_embed: + inputs: + 0: + min: 0 + max: 26743 + avg_min: 32.519748718318944 + avg_max: 26716.69644017626 + mean: 13429.908568849716 + std: 7695.305228284578 + shape: (1000) + output: + min: -3.886291742324829 + max: 3.3996453285217285 + avg_min: -1.8537866596848787 + avg_max: 1.733846694667886 + mean: -0.1020138903885282 + std: 0.9169524390367664 + shape: (1000, 64) +mlp_user_embed: + inputs: + 0: + min: 0 + max: 13848 + avg_min: 6924.0 + avg_max: 6924.0 + mean: 6924.0 + std: 3997.8620729189115 + shape: (1000) + output: + min: -2.3077099323272705 + max: 2.019761323928833 + avg_min: -0.8393908124364596 + avg_max: 0.8563097013907461 + mean: -0.0058890159863465566 + std: 0.33145011084640036 + shape: (1000, 128) +mlp_item_embed: + inputs: + 0: + min: 0 + max: 26743 + avg_min: 32.519748718318944 + avg_max: 26716.69644017626 + mean: 13429.908568849716 + std: 
7695.305228284578 + shape: (1000) + output: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.6649888157199006 + avg_max: 1.5336890253437756 + mean: 0.03420270613169851 + std: 0.6040079522209375 + shape: (1000, 128) +mf_mult: + inputs: + 0: + min: -2.3631081581115723 + max: 2.634204864501953 + avg_min: -1.071664170499179 + avg_max: 1.10968593620068 + mean: 0.004670127670715561 + std: 0.4853287812380178 + shape: (1000, 64) + 1: + min: -3.886291742324829 + max: 3.3996453285217285 + avg_min: -1.8537866596848787 + avg_max: 1.733846694667886 + mean: -0.1020138903885282 + std: 0.9169524390367664 + shape: (1000, 64) + output: + min: -6.388758659362793 + max: 7.461198329925537 + avg_min: -1.4091147140708975 + avg_max: 1.190038402591932 + mean: -0.03373257358976464 + std: 0.47984411373143276 + shape: (1000, 64) +mlp_concat: + inputs: + 0: + min: -2.3077099323272705 + max: 2.019761323928833 + avg_min: -0.8393908124364596 + avg_max: 0.8563097013907461 + mean: -0.0058890159863465566 + std: 0.33145011084640036 + shape: (1000, 128) + 1: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.6649888157199006 + avg_max: 1.5336890253437756 + mean: 0.03420270613169851 + std: 0.6040079522209375 + shape: (1000, 128) + output: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.728540082865832 + avg_max: 1.5992073092841723 + mean: 0.014156845084580252 + std: 0.4875919206474072 + shape: (1000, 256) +mlp.0: + inputs: + 0: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.728540082865832 + avg_max: 1.5992073092841723 + mean: 0.014156845084580252 + std: 0.4875919206474072 + shape: (1000, 256) + output: + min: -25.551782608032227 + max: 30.255319595336914 + avg_min: -10.410937029339797 + avg_max: 12.199230058019891 + mean: -1.3596681231175995 + std: 3.435064250450655 + shape: (1000, 256) +mlp.1: + inputs: + 0: + min: 0.0 + max: 30.255319595336914 + avg_min: 0.0 + avg_max: 12.199230058019891 + mean: 0.6925220241296313 + std: 
1.6946167814794522 + shape: (1000, 256) + output: + min: -231.78152465820312 + max: 82.65782165527344 + avg_min: -60.22239387931445 + avg_max: 21.538379845476946 + mean: -11.969065733546032 + std: 16.571529820325505 + shape: (1000, 128) +mlp.2: + inputs: + 0: + min: 0.0 + max: 82.65782165527344 + avg_min: 0.0 + avg_max: 21.538379845476946 + mean: 1.362308199108358 + std: 3.989317218613674 + shape: (1000, 128) + output: + min: -235.94625854492188 + max: 203.7071990966797 + avg_min: -54.71186078950498 + avg_max: 29.957007628019554 + mean: -7.084710118819055 + std: 21.72214522135367 + shape: (1000, 64) +mlp_relu.0: + inputs: + 0: + min: -25.551782608032227 + max: 30.255319595336914 + avg_min: -10.410937029339797 + avg_max: 12.199230058019891 + mean: -1.3596681231175995 + std: 3.435064250450655 + shape: (1000, 256) + output: + min: 0.0 + max: 30.255319595336914 + avg_min: 0.0 + avg_max: 12.199230058019891 + mean: 0.6925220241296313 + std: 1.6946167814794522 + shape: (1000, 256) +mlp_relu.1: + inputs: + 0: + min: -231.78152465820312 + max: 82.65782165527344 + avg_min: -60.22239387931445 + avg_max: 21.538379845476946 + mean: -11.969065733546032 + std: 16.571529820325505 + shape: (1000, 128) + output: + min: 0.0 + max: 82.65782165527344 + avg_min: 0.0 + avg_max: 21.538379845476946 + mean: 1.362308199108358 + std: 3.989317218613674 + shape: (1000, 128) +mlp_relu.2: + inputs: + 0: + min: -235.94625854492188 + max: 203.7071990966797 + avg_min: -54.71186078950498 + avg_max: 29.957007628019554 + mean: -7.084710118819055 + std: 21.72214522135367 + shape: (1000, 64) + output: + min: 0.0 + max: 203.7071990966797 + avg_min: 0.0 + avg_max: 29.95700772112138 + mean: 4.937804440873921 + std: 11.42688696572864 + shape: (1000, 64) +final_mlp: + inputs: + 0: + min: 0.0 + max: 203.7071990966797 + avg_min: 0.0 + avg_max: 29.95700772112138 + mean: 4.937804440873921 + std: 11.42688696572864 + shape: (1000, 64) + output: + min: -283.5218200683594 + max: 40.86720275878906 + avg_min: 
-60.916563263919905 + avg_max: 6.9745166829700596 + mean: -22.90735553673859 + std: 35.900340843176544 + shape: (1000, 1) +final_mf: + inputs: + 0: + min: -6.388758659362793 + max: 7.461198329925537 + avg_min: -1.4091147140708975 + avg_max: 1.190038402591932 + mean: -0.03373257358976464 + std: 0.47984411373143276 + shape: (1000, 64) + output: + min: -54.07410430908203 + max: 47.101890563964844 + avg_min: -16.45866186288787 + avg_max: 7.688797266098178 + mean: -4.42458063103854 + std: 9.54055961838881 + shape: (1000, 1) +final_add: + inputs: + 0: + min: -54.07410430908203 + max: 47.101890563964844 + avg_min: -16.45866186288787 + avg_max: 7.688797266098178 + mean: -4.42458063103854 + std: 9.54055961838881 + shape: (1000, 1) + 1: + min: -283.5218200683594 + max: 40.86720275878906 + avg_min: -60.916563263919905 + avg_max: 6.9745166829700596 + mean: -22.90735553673859 + std: 35.900340843176544 + shape: (1000, 1) + output: + min: -264.23663330078125 + max: 10.719744682312012 + avg_min: -64.09727644866952 + avg_max: 4.514405736883009 + mean: -27.33193622736199 + std: 29.674832823876194 + shape: (1000, 1) +sigmoid: + inputs: + 0: + min: -264.23663330078125 + max: 10.719744682312012 + avg_min: -64.09727644866952 + avg_max: 4.514405736883009 + mean: -27.33193622736199 + std: 29.674832823876194 + shape: (1000, 1) + output: + min: 0.0 + max: 0.9999779462814331 + avg_min: 7.221196356095192e-08 + avg_max: 0.9796236107793396 + mean: 0.025780336994990046 + std: 0.11967490732565136 + shape: (1000, 1) diff --git a/examples/ncf/utils.py b/examples/ncf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4395830c050b54fbfc4212dff4bbf5d3c755ed00 --- /dev/null +++ b/examples/ncf/utils.py @@ -0,0 +1,41 @@ +import os +import json +from functools import reduce + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 
0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def count_parameters(model): + c = map(lambda p: reduce(lambda x, y: x * y, p.size()), model.parameters()) + return sum(c) + + +def save_config(config, run_dir): + path = os.path.join(run_dir, "config_{}.json".format(config['timestamp'])) + with open(path, 'w') as config_file: + json.dump(config, config_file) + config_file.write('\n') + + +def save_result(result, path): + write_heading = not os.path.exists(path) + with open(path, mode='a') as out: + if write_heading: + out.write(",".join([str(k) for k, v in result.items()]) + '\n') + out.write(",".join([str(v) for k, v in result.items()]) + '\n') diff --git a/examples/ncf/verify_dataset.sh b/examples/ncf/verify_dataset.sh new file mode 100755 index 0000000000000000000000000000000000000000..208d7602a8fad8bf3151edf90e8ecd2641195e14 --- /dev/null +++ b/examples/ncf/verify_dataset.sh @@ -0,0 +1,44 @@ +function get_checker { + if [[ "$OSTYPE" == "darwin"* ]]; then + checkmd5=md5 + else + checkmd5=md5sum + fi + + echo $checkmd5 +} + + +function verify_1m { + # From: curl -O http://files.grouplens.org/datasets/movielens/ml-1m.zip.md5 + hash=<(echo "MD5 (ml-1m.zip) = c4d9eecfca2ab87c1945afe126590906") + local checkmd5=$(get_checker) + if diff <($checkmd5 ml-1m.zip) $hash &> /dev/null + then + echo "PASSED" + else + echo "FAILED" + fi +} + +function verify_20m { + # From: curl -O http://files.grouplens.org/datasets/movielens/ml-20m.zip.md5 + hash=<(echo "MD5 (ml-20m.zip) = cd245b17a1ae2cc31bb14903e1204af3") + local checkmd5=$(get_checker) + + if diff <($checkmd5 ml-20m.zip) $hash &> /dev/null + then + echo "PASSED" + else + echo "FAILED" + fi + +} + + +if [[ $1 == "ml-1m" ]] +then + verify_1m +else + verify_20m +fi diff --git a/jupyter/truncated_svd.ipynb b/jupyter/truncated_svd.ipynb index 7331d3c54b228f57364c1d27a3c17fbf9b94ecc3..07f41d657c30f42480b7ef20c58de434aaa1ed71 100644 --- 
a/jupyter/truncated_svd.ipynb +++ b/jupyter/truncated_svd.ipynb @@ -40,12 +40,54 @@ " - Top1: 75.65 \n", " - Top5: 92.75\n", " \n", - " Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000) " + " Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000)\n", + " \n", + "## Details\n", + "\n", + "[SVD (Singular Value Decomposition)](https://en.wikipedia.org/wiki/Singular_value_decomposition) is an exact factorization of a matrix, W (of shape m x n), to the form USV$^T$ (U is m x m, S is m x n, V$^T$ is n x n; V$^T$ is the transpose of V). Every matrix has an SVD.\n", + "\n", + "A Linear (fully-connected) layer performs: y = Wx + b (or y = xW$^T$ + b)<br>\n", + "We can use SVD to refactor W to rewrite this as: y = (USV$^T$)x + b\n", + "\n", + "So far, we haven’t done any compression, so let’s get to it using Truncated SVD.<br> TSVD is a method to provide an approximated decomposition of W, in which S has a lower rank. We want to find an approximation of W that is “good enough” and also accelerates the computation of Wx.\n", + "\n", + "We choose some lower-rank, k, such that k<m (preferably k<<m).<br>\n", + "TSVD is straightforward: keep the largest k singular values of S and discard the rest (truncate S).\n", + "\n", + "After TSVD we have:\n", + "U’ is m x k, S’ is k x k, V’$^T$ is k x n.<br>\n", + "y ~ (U’S’V’$^T$)x + b<br>\n", + "y ~ (U’(S’V’$^T$))x + b<br>\n", + "\n", + "We'll replace S’V’$^T$ with A, because we can pre-compute it once. 
A has shape k x n:<br>\n", + "y ~ (U’A)x + b<br>\n", + "y ~ U’(Ax) + b<br>\n", + "\n", + "Let’s ignore the bias and calculate the number of parameters and FLOPs (floating point operations) for the original y:\n", + "\n", + " - m * n weight coefficients<br>\n", + " - m * n FLOPs (for batch size = 1)<br>\n", + "\n", + "After TSVD we have:\n", + "\n", + " - mk + kn = k*(m+n) weight coefficients<br>\n", + " - kn + mk = k*(m+n) FLOPs (for batch size = 1)<br>\n", + "\n", + "To actually compress the weights after TSVD, we want: m * n > k*(m+n)<br>\n", + "Let’s rewrite k in terms of m: k = tm<br>\n", + "m * n > tm*(m+n)<br>\n", + "n > t*(m+n)<br>\n", + "n / (m+n) > t<br>\n", + "\n", + "This is the math, but for an actual performance increase, we should strive for m * n >> k*(m+n)\n", + "\n", + "In the example notebook: m = 1000; n=2048<br>\n", + "So when t=2048/(1000+2048) (that is, k=2048/3048*1000=672), we have equilibrium. When 0.672>t (i.e. k is smaller than 672), the sum of the size of the weights of A and U’ is smaller than the size of W.\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -80,7 +122,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -105,14 +147,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 4\n", "\n", "# Data loader\n", - "test_loader = imagenet_load_data(\"../../data.imagenet/\", \n", + "test_loader = imagenet_load_data(\"/datasets/imagenet/\", \n", " batch_size=BATCH_SIZE, \n", " num_workers=2)\n", " \n", @@ -136,11 +178,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fc_layer(2048, 1000)\n", + "W = torch.Size([1000, 2048])\n", + "U = 
torch.Size([1000, 400])\n", + "SV = torch.Size([400, 2048])\n" + ] + } + ], "source": [ "# Load the various models\n", "resnet50 = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=False)\n", @@ -170,11 +223,11 @@ "\n", "\n", "class TruncatedSVD(nn.Module):\n", - " def __init__(self, replaced_gemm, gemm_weights):\n", - " super(TruncatedSVD,self).__init__()\n", + " def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):\n", + " super().__init__()\n", " self.replaced_gemm = replaced_gemm\n", " print(\"W = {}\".format(gemm_weights.shape))\n", - " self.U, self.SV = truncated_svd(gemm_weights.data, int(0.4 * gemm_weights.size(0)))\n", + " self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))\n", " print(\"U = {}\".format(self.U.shape))\n", " \n", " self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()\n", @@ -182,20 +235,21 @@ " \n", " print(\"SV = {}\".format(self.SV.shape))\n", " self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()\n", - " self.fc_sv.weight.data = self.SV#.t()\n", - " \n", + " self.fc_sv.weight.data = self.SV#.t() \n", "\n", " def forward(self, x):\n", " x = self.fc_sv.forward(x)\n", " x = self.fc_u.forward(x)\n", " return x\n", "\n", + " \n", "def replace(model):\n", " fc_weights = model.state_dict()['fc.weight']\n", " fc_layer = model.fc\n", " print(\"fc_layer({}, {})\".format(fc_layer.in_features, fc_layer.out_features))\n", - " model.fc = TruncatedSVD(fc_layer, fc_weights)\n", + " model.fc = TruncatedSVD(fc_layer, fc_weights, 0.4)\n", "\n", + " \n", "from copy import deepcopy\n", "resnet50 = deepcopy(resnet50)\n", "replace(resnet50)" @@ -203,18 +257,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "progress: 51200 images\n", + "progress: 102400 images\n", + "progress: 153600 
images\n", + "progress: 204800 images\n", + "progress: 256000 images\n", + "progress: 307200 images\n", + "progress: 358400 images\n", + "Top1: 75.70 Top5: 92.76\n", + "Duration: 168.111492395401\n" + ] + } + ], "source": [ "# Standard loop to test the accuracy of a model.\n", "\n", "import time\n", "import torchnet.meter as tnt\n", "t0 = time.time()\n", - "test_loader = imagenet_load_data(\"../../data.imagenet\", \n", + "test_loader = imagenet_load_data(\"/datasets/imagenet\", \n", " batch_size=64, \n", " num_workers=4,\n", " shuffle=False)\n", @@ -234,6 +304,13 @@ "t2 = time.time()\n", "print(\"Duration: \", t2-t0)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/licenses/py_faster_rcnn-license.txt b/licenses/py_faster_rcnn-license.txt new file mode 100755 index 0000000000000000000000000000000000000000..1ab42b27a3ac66e841a94d6f568be493efcde274 --- /dev/null +++ b/licenses/py_faster_rcnn-license.txt @@ -0,0 +1,81 @@ +Faster R-CNN + +The MIT License (MIT) + +Copyright (c) 2015 Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +************************************************************************ + +THIRD-PARTY SOFTWARE NOTICES AND INFORMATION + +This project, Faster R-CNN, incorporates material from the project(s) +listed below (collectively, "Third Party Code"). Microsoft is not the +original author of the Third Party Code. The original copyright notice +and license under which Microsoft received such Third Party Code are set +out below. This Third Party Code is licensed to you under their original +license terms set forth below. Microsoft reserves all other rights not +expressly granted, whether by implication, estoppel or otherwise. + +1. Caffe, (https://github.com/BVLC/caffe/) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014, 2015, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, 2015, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright +over their contributions to Caffe. The project versioning records all +such contribution and copyright details. If a contributor wants to +further mark their specific copyright on a particular contribution, +they should indicate their copyright solely in the commit message of +the change when it is committed. + +The BSD 2-Clause License + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +2. 
Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** diff --git a/requirements.txt b/requirements.txt index 06e56c4711c81dce3432bb7ce938c2baea71fe3e..c69c7065753c638ef5c5a77c9d7077fbc6855c43 100755 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ xlsxwriter>=1.1.1 pretrainedmodels==0.7.4 scikit-learn==0.21.2 gym==0.12.5 +tqdm==4.33.0