diff --git a/.gitignore b/.gitignore index ccc482df49900ddb8aea041ed509d3509d9a174f..0fc4bc08f311f284ea86b2123a382e06d5942caf 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ logs/ __pycache__/ .pytest_cache .cache +examples/ncf/run/ +examples/ncf/ml-20m* # Virtual env env/ diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py index 06d5d9a31233798b1c4b6112a6d38cc73e0c18e8..e24c97686556d708d111f4b3bc7023bcff5985e9 100644 --- a/distiller/quantization/__init__.py +++ b/distiller/quantization/__init__.py @@ -16,7 +16,7 @@ from .quantizer import Quantizer from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, PostTrainLinearQuantizer, \ - LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args,\ + LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \ RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index fb460c045b869223e75159044887e5d211f15bb5..8984ba28160c3c7aedf3d8b56395aca05482ba3c 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -1076,3 +1076,21 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer): per_channel=perch) m.register_buffer(ptq.q_attr_name + '_scale', torch.ones_like(scale)) m.register_buffer(ptq.q_attr_name + '_zero_point', torch.zeros_like(zero_point)) + + +class NCFQuantAwareTrainQuantizer(QuantAwareTrainRangeLinearQuantizer): + def __init__(self, model, optimizer=None, bits_activations=32, bits_weights=32, bits_bias=32, + overrides=None, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, per_channel_wts=False): + super(NCFQuantAwareTrainQuantizer, self).__init__(model, optimizer=optimizer, + bits_activations=bits_activations, + bits_weights=bits_weights, + bits_bias=bits_bias, + overrides=overrides, + mode=mode, ema_decay=ema_decay, + per_channel_wts=per_channel_wts, + quantize_inputs=False) + + self.replacement_factory[distiller.modules.EltwiseMult] = self.activation_replace_fn + self.replacement_factory[distiller.modules.Concat] = self.activation_replace_fn + self.replacement_factory[nn.Linear] = self.activation_replace_fn + # self.replacement_factory[nn.Sigmoid] = self.activation_replace_fn diff --git a/examples/ncf/MLPERF_LICENSE.md b/examples/ncf/MLPERF_LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..89ba5be4833f1b8418c869f63ab8ee1f6e881531 --- /dev/null +++ b/examples/ncf/MLPERF_LICENSE.md @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018 The MLPerf Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/examples/ncf/README.md b/examples/ncf/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fc477a629975fbc7ae012e5ab8dda95ad785f5e8 --- /dev/null +++ b/examples/ncf/README.md @@ -0,0 +1,96 @@ +# NCF - Neural Collaborative Filtering + +The NCF implementation provided here is based on the implementation found in the MLPerf Training GitHub repository, specifically on the last revision of the code before the switch to the extended dataset. See [here](https://github.com/mlperf/training/tree/fe17e837ed12974d15c86d5173fe8f2c188434d5/recommendation/pytorch). + +We've made several modifications to the code: +* Removed all MLPerf specific code including logging +* In `ncf.py`: + * Added calls to Distiller compression APIs + * Added progress indication in training and evaluation flows +* In `neumf.py`: + * Added option to split final FC layer + * Replaced all functional calls with modules so they can be detected by Distiller +* In `dataset.py`: + * Speed up data loading - On first data will is loaded from CSVs and then pickled. On subsequent runs the pickle is loaded. This is much faster than the original implementation, but still very slow. + * Added progress indication during data load process +* Removed some irrelevant content from this README + +## Problem + +This task benchmarks recommendation with implicit feedback on the [MovieLens 20 Million (ml-20m) dataset](https://grouplens.org/datasets/movielens/20m/) with a [Neural Collaborative Filtering](http://dl.acm.org/citation.cfm?id=3052569) model. +The model trains on binary information about whether or not a user interacted with a specific item. + +## Setup + +### Steps to configure machine + +1. Install `unzip` and `curl` + +```bash +sudo apt-get install unzip curl +``` + +2. Install required python packages + +```bash +pip install -r requirements.txt +``` + +3. Download and verify data + +```bash +# Creates ml-20.zip +source ../download_dataset.sh +# Confirms the MD5 checksum of ml-20.zip +source ../verify_dataset.sh +``` + +## Running the Sample + +### TODO: Add some Distiller specific example command line + +## Dataset/Environment + +### Publication/Attribution + +Harper, F. M. & Konstan, J. A. (2015), 'The MovieLens Datasets: History and Context', ACM Trans. Interact. Intell. Syst. 5(4), 19:1--19:19. + +### Data preprocessing + +1. Unzip +2. Remove users with less than 20 reviews +3. Create training and test data separation described below + +### Training and test data separation + +Positive training examples are all but the last item each user rated. +Negative training examples are randomly selected from the unrated items for each user. + +The last item each user rated is used as a positive example in the test set. +A fixed set of 999 unrated items are also selected to calculate hit rate at 10 for predicting the test item. + +### Training data order + +Data is traversed randomly with 4 negative examples selected on average for every positive example. + +## Model + +### Publication/Attribution + +Xiangnan He, Lizi Liao, Hanwang Zhang, Liqiang Nie, Xia Hu and Tat-Seng Chua (2017). [Neural Collaborative Filtering](http://dl.acm.org/citation.cfm?id=3052569). In Proceedings of WWW '17, Perth, Australia, April 03-07, 2017. + +The author's original code is available at [hexiangnan/neural_collaborative_filtering](https://github.com/hexiangnan/neural_collaborative_filtering). + +## Quality + +### Quality metric + +Hit rate at 10 (HR@10) with 999 negative items. + +### Evaluation frequency + +After every epoch through the training data. + +### Evaluation thoroughness + +Every users last item rated, i.e. all held out positive examples. diff --git a/examples/ncf/convert.py b/examples/ncf/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8e5d7ccd33ffaecaa0a63e3fd093386d70c7fb --- /dev/null +++ b/examples/ncf/convert.py @@ -0,0 +1,101 @@ +import os +from argparse import ArgumentParser +from collections import defaultdict + +import numpy as np +import pandas as pd +from tqdm import tqdm + +from load import implicit_load + + +MIN_RATINGS = 20 + + +USER_COLUMN = 'user_id' +ITEM_COLUMN = 'item_id' + + +TRAIN_RATINGS_FILENAME = 'train-ratings.csv' +TEST_RATINGS_FILENAME = 'test-ratings.csv' +TEST_NEG_FILENAME = 'test-negative.csv' + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument('path', type=str, + help='Path to reviews CSV file from MovieLens') + parser.add_argument('output', type=str, + help='Output directory for train and test CSV files') + parser.add_argument('-n', '--negatives', type=int, default=999, + help='Number of negative samples for each positive' + 'test example') + parser.add_argument('-s', '--seed', type=int, default=0, + help='Random seed to reproduce same negative samples') + return parser.parse_args() + + +def main(): + args = parse_args() + np.random.seed(args.seed) + + print("Loading raw data from {}".format(args.path)) + df = implicit_load(args.path, sort=False) + print("Filtering out users with less than {} ratings".format(MIN_RATINGS)) + grouped = df.groupby(USER_COLUMN) + df = grouped.filter(lambda x: len(x) >= MIN_RATINGS) + + print("Mapping original user and item IDs to new sequential IDs") + original_users = df[USER_COLUMN].unique() + original_items = df[ITEM_COLUMN].unique() + + user_map = {user: index for index, user in enumerate(original_users)} + item_map = {item: index for index, item in enumerate(original_items)} + + df[USER_COLUMN] = df[USER_COLUMN].apply(lambda user: user_map[user]) + df[ITEM_COLUMN] = df[ITEM_COLUMN].apply(lambda item: item_map[item]) + + assert df[USER_COLUMN].max() == len(original_users) - 1 + assert df[ITEM_COLUMN].max() == len(original_items) - 1 + + print("Creating list of items for each user") + # Need to sort before popping to get last item + df.sort_values(by='timestamp', inplace=True) + all_ratings = set(zip(df[USER_COLUMN], df[ITEM_COLUMN])) + user_to_items = defaultdict(list) + for row in tqdm(df.itertuples(), desc='Ratings', total=len(df)): + user_to_items[getattr(row, USER_COLUMN)].append(getattr(row, ITEM_COLUMN)) # noqa: E501 + + test_ratings = [] + test_negs = [] + all_items = set(range(len(original_items))) + print("Generating {} negative samples for each user" + .format(args.negatives)) + for user in tqdm(range(len(original_users)), desc='Users', total=len(original_users)): # noqa: E501 + test_item = user_to_items[user].pop() + + all_ratings.remove((user, test_item)) + all_negs = all_items - set(user_to_items[user]) + all_negs = sorted(list(all_negs)) # determinism + + test_ratings.append((user, test_item)) + test_negs.append(list(np.random.choice(all_negs, args.negatives))) + + print("Saving train and test CSV files to {}".format(args.output)) + df_train_ratings = pd.DataFrame(list(all_ratings)) + df_train_ratings['fake_rating'] = 1 + df_train_ratings.to_csv(os.path.join(args.output, TRAIN_RATINGS_FILENAME), + index=False, header=False, sep='\t') + + df_test_ratings = pd.DataFrame(test_ratings) + df_test_ratings['fake_rating'] = 1 + df_test_ratings.to_csv(os.path.join(args.output, TEST_RATINGS_FILENAME), + index=False, header=False, sep='\t') + + df_test_negs = pd.DataFrame(test_negs) + df_test_negs.to_csv(os.path.join(args.output, TEST_NEG_FILENAME), + index=False, header=False, sep='\t') + + +if __name__ == '__main__': + main() diff --git a/examples/ncf/dataset.py b/examples/ncf/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f69fd6ceca410c94fb382aa8889e8e118d2ee985 --- /dev/null +++ b/examples/ncf/dataset.py @@ -0,0 +1,132 @@ +import numpy as np +import scipy +import scipy.sparse +import torch +import torch.utils.data +import subprocess +import time +from tqdm import tqdm +import os +import pickle +import logging + +msglogger = logging.getLogger() + + +def wccount(filename): + out = subprocess.Popen(['wc', '-l', filename], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT + ).communicate()[0] + return int(out.partition(b' ')[0]) + + +class TimingContext(object): + def __init__(self, desc): + self.desc = desc + + def __enter__(self): + msglogger.info(self.desc + ' ... ') + self.start = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + end = time.time() + msglogger.info('Done in {0:.4f} seconds'.format(end - self.start)) + return True + + +class CFTrainDataset(torch.utils.data.dataset.Dataset): + def __init__(self, train_fname, nb_neg): + self._load_train_matrix(train_fname) + self.nb_neg = nb_neg + + def _load_train_matrix(self, train_fname): + pkl_name = os.path.splitext(train_fname)[0] + '_data.pkl' + npz_name = os.path.splitext(train_fname)[0] + '_mat.npz' + + if os.path.isfile(pkl_name) and os.path.isfile(npz_name): + msglogger.info('Found saved dataset data structures') + with TimingContext('Loading data list pickle'), open(pkl_name, 'rb') as f: + self.data = pickle.load(f) + with TimingContext('Loading matrix npz'): + self.mat = scipy.sparse.dok_matrix(scipy.sparse.load_npz(npz_name)) + self.nb_users = self.mat.shape[0] + self.nb_items = self.mat.shape[1] + else: + def process_line(line): + tmp = line.split('\t') + return [int(tmp[0]), int(tmp[1]), float(tmp[2]) > 0] + + with TimingContext('Loading CSV file'), open(train_fname, 'r') as file: + data = list(map(process_line, tqdm(file, total=wccount(train_fname)))) + + with TimingContext('Calculating min/max'): + self.nb_users = max(data, key=lambda x: x[0])[0] + 1 + self.nb_items = max(data, key=lambda x: x[1])[1] + 1 + + with TimingContext('Constructing data list'): + self.data = list(filter(lambda x: x[2], data)) + + with TimingContext('Saving data list pickle'), open(pkl_name, 'wb') as f: + pickle.dump(self.data, f) + + with TimingContext('Building dok matrix'): + self.mat = scipy.sparse.dok_matrix( + (self.nb_users, self.nb_items), dtype=np.float32) + for user, item, _ in tqdm(data): + self.mat[user, item] = 1. + + with TimingContext('Converting to COO matrix and saving'): + scipy.sparse.save_npz(npz_name, self.mat.tocoo(copy=True)) + + def __len__(self): + return (self.nb_neg + 1) * len(self.data) + + def __getitem__(self, idx): + if idx % (self.nb_neg + 1) == 0: + idx = idx // (self.nb_neg + 1) + return self.data[idx][0], self.data[idx][1], np.ones(1, dtype=np.float32) # noqa: E501 + else: + idx = idx // (self.nb_neg + 1) + u = self.data[idx][0] + j = torch.LongTensor(1).random_(0, self.nb_items).item() + while (u, j) in self.mat: + j = torch.LongTensor(1).random_(0, self.nb_items).item() + return u, j, np.zeros(1, dtype=np.float32) + + +def load_test_ratings(fname): + pkl_name = os.path.splitext(fname)[0] + '.pkl' + if os.path.isfile(pkl_name): + with TimingContext('Found test rating pickle file - loading'), open(pkl_name, 'rb') as f: + res = pickle.load(f) + else: + def process_line(line): + tmp = map(int, line.split('\t')[0:2]) + return list(tmp) + with TimingContext('Loading test ratings from csv'), open(fname, 'r') as f: + ratings = map(process_line, tqdm(f, total=wccount(fname))) + res = list(ratings) + with TimingContext('Saving test ratings list pickle'), open(pkl_name, 'wb') as f: + pickle.dump(res, f) + + return res + + +def load_test_negs(fname): + pkl_name = os.path.splitext(fname)[0] + '.pkl' + if os.path.isfile(pkl_name): + with TimingContext('Found test negatives pickle file - loading'), open(pkl_name, 'rb') as f: + res = pickle.load(f) + else: + def process_line(line): + tmp = map(int, line.split('\t')) + return list(tmp) + with TimingContext('Loading test negatives from csv'), open(fname, 'r') as f: + negs = map(process_line, tqdm(f, total=wccount(fname))) + res = list(negs) + with TimingContext('Saving test negatives list pickle'), open(pkl_name, 'wb') as f: + pickle.dump(res, f) + + return res diff --git a/examples/ncf/download_dataset.sh b/examples/ncf/download_dataset.sh new file mode 100755 index 0000000000000000000000000000000000000000..8876230f14123358e895595d0126cd9537733908 --- /dev/null +++ b/examples/ncf/download_dataset.sh @@ -0,0 +1,16 @@ +function download_20m { + echo "Download ml-20m" + curl -O http://files.grouplens.org/datasets/movielens/ml-20m.zip +} + +function download_1m { + echo "Downloading ml-1m" + curl -O http://files.grouplens.org/datasets/movielens/ml-1m.zip +} + +if [[ $1 == "ml-1m" ]] +then + download_1m +else + download_20m +fi diff --git a/examples/ncf/load.py b/examples/ncf/load.py new file mode 100644 index 0000000000000000000000000000000000000000..304f43c2bf9836e212a29fa29fb8820320b460d6 --- /dev/null +++ b/examples/ncf/load.py @@ -0,0 +1,68 @@ +from collections import namedtuple + +import pandas as pd + + +RatingData = namedtuple('RatingData', + ['items', 'users', 'ratings', 'min_date', 'max_date']) + + +def describe_ratings(ratings): + info = RatingData(items=len(ratings['item_id'].unique()), + users=len(ratings['user_id'].unique()), + ratings=len(ratings), + min_date=ratings['timestamp'].min(), + max_date=ratings['timestamp'].max()) + print("{ratings} ratings on {items} items from {users} users" + " from {min_date} to {max_date}" + .format(**(info._asdict()))) + return info + + +def process_movielens(ratings, sort=True): + ratings['timestamp'] = pd.to_datetime(ratings['timestamp'], unit='s') + if sort: + ratings.sort_values(by='timestamp', inplace=True) + describe_ratings(ratings) + return ratings + + +def load_ml_100k(filename, sort=True): + names = ['user_id', 'item_id', 'rating', 'timestamp'] + ratings = pd.read_csv(filename, sep='\t', names=names) + return process_movielens(ratings, sort=sort) + + +def load_ml_1m(filename, sort=True): + names = ['user_id', 'item_id', 'rating', 'timestamp'] + ratings = pd.read_csv(filename, sep='::', names=names, engine='python') + return process_movielens(ratings, sort=sort) + + +def load_ml_10m(filename, sort=True): + names = ['user_id', 'item_id', 'rating', 'timestamp'] + ratings = pd.read_csv(filename, sep='::', names=names, engine='python') + return process_movielens(ratings, sort=sort) + + +def load_ml_20m(filename, sort=True): + ratings = pd.read_csv(filename) + ratings['timestamp'] = pd.to_datetime(ratings['timestamp'], unit='s') + names = {'userId': 'user_id', 'movieId': 'item_id'} + ratings.rename(columns=names, inplace=True) + return process_movielens(ratings, sort=sort) + + +DATASETS = [k.replace('load_', '') for k in locals().keys() if "load_" in k] + + +def get_dataset_name(filename): + for dataset in DATASETS: + if dataset in filename.replace('-', '_').lower(): + return dataset + raise NotImplementedError + + +def implicit_load(filename, sort=True): + func = globals()["load_" + get_dataset_name(filename)] + return func(filename, sort=sort) diff --git a/examples/ncf/logging.conf b/examples/ncf/logging.conf new file mode 100755 index 0000000000000000000000000000000000000000..8db92a75fccb779dc2c02b8e7c668d3cf24363c4 --- /dev/null +++ b/examples/ncf/logging.conf @@ -0,0 +1,38 @@ +[formatters] +keys: simple, time_simple + +[handlers] +keys: console, file + +[loggers] +keys: root, app_cfg + +[formatter_simple] +format: %(message)s + +[formatter_time_simple] +format: %(asctime)s - %(message)s + +[handler_console] +class: StreamHandler +propagate: 0 +args: [] +formatter: simple + +[handler_file] +class: FileHandler +mode: 'w' +args=('%(logfilename)s', 'w') +formatter: time_simple + +[logger_root] +level: INFO +propagate: 1 +handlers: console, file + +[logger_app_cfg] +# Use this logger to log the application configuration and execution environment +level: DEBUG +qualname: app_cfg +propagate: 0 +handlers: file diff --git a/examples/ncf/ncf.py b/examples/ncf/ncf.py new file mode 100644 index 0000000000000000000000000000000000000000..095ac0db98ba0e7a6e80c812cd6ed3702603891a --- /dev/null +++ b/examples/ncf/ncf.py @@ -0,0 +1,491 @@ +import os +import heapq +import math +import time +from functools import partial +from datetime import datetime +from collections import OrderedDict +from argparse import ArgumentParser +import sys + +import tqdm +import numpy as np +import torch +import torch.nn as nn +from torch import multiprocessing as mp + +import utils +from neumf import NeuMF +from dataset import CFTrainDataset, load_test_ratings, load_test_negs +from convert import (TEST_NEG_FILENAME, TEST_RATINGS_FILENAME, + TRAIN_RATINGS_FILENAME) + +import distiller +import distiller.quantization as quantization +import distiller.apputils as apputils +from distiller.data_loggers import TensorBoardLogger, PythonLogger + +msglogger = None + + +def parse_args(): + parser = ArgumentParser(description="Train a Nerual Collaborative" + " Filtering model") + parser.add_argument('data', type=str, + help='path to test and training data files') + parser.add_argument('-e', '--epochs', type=int, default=20, + help='number of epochs for training') + parser.add_argument('-b', '--batch-size', type=int, default=256, + help='number of examples for each iteration') + parser.add_argument('-f', '--factors', type=int, default=8, + help='number of predictive factors') + parser.add_argument('--layers', nargs='+', type=int, + default=[64, 32, 16, 8], + help='size of hidden layers for MLP') + parser.add_argument('-n', '--negative-samples', type=int, default=4, + help='number of negative examples per interaction') + parser.add_argument('-l', '--learning-rate', type=float, default=0.001, + help='learning rate for optimizer') + parser.add_argument('-k', '--topk', type=int, default=10, + help='rank for test examples to be considered a hit') + parser.add_argument('--no-cuda', action='store_true', + help='use available GPUs') + parser.add_argument('--seed', '-s', type=int, + help='manually set random seed for torch') + parser.add_argument('--threshold', '-t', type=float, + help='stop training early at threshold') + parser.add_argument('--processes', '-p', type=int, default=1, + help='Number of processes for evaluating model') + parser.add_argument('--workers', '-w', type=int, default=8, + help='Number of workers for training DataLoader') + + # Distiller Args + # summary_choices = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx'] + # parser.add_argument('--summary', type=str, choices=summary_choices, + # help='print a summary of the model, and exit - options: ' + + # ' | '.join(summary_choices)) + parser.add_argument('--load', type=str, metavar='PATH') + parser.add_argument('--reset-optimizer', action='store_true') + parser.add_argument('--eval', '--evaluate', action='store_true') + parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store', + help='configuration file for pruning the model (default is to use hard-coded schedule)') + parser.add_argument('--gpus', metavar='DEV_ID', default=None, + help='Comma-separated list of GPU device IDs to be used ' + '(default is to use all available devices)') + parser.add_argument('--out-dir', '-o', dest='output_dir', default=os.path.join('run', 'neumf'), + help='Path to dump logs and checkpoints') + parser.add_argument('--name', metavar='NAME', default=None, help='Experiment name') + parser.add_argument('--log-freq', '--lf', default=100, type=int, metavar='N', help='Logging frequency') + parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False, + help='log the parameter tensors histograms to file ' + '(WARNING: this can use significant disk space)') + parser.add_argument('--split-final', '--sf', action='store_true') + parser.add_argument('--eval-fp16', action='store_true') + parser.add_argument('--activation-histograms', '--act-hist', + type=distiller.utils.float_range_argparse_checker(exc_min=True), + metavar='PORTION_OF_TEST_SET', + help='Run the model in evaluation mode on the specified portion of the test dataset and ' + 'generate activation histograms. NOTE: This slows down evaluation significantly') + quantization.add_post_train_quant_args(parser) + + return parser.parse_args() + + +def predict(model, users, items, batch_size=1024, use_cuda=True): + with torch.no_grad(): + batches = [(users[i:i + batch_size], items[i:i + batch_size]) + for i in range(0, len(users), batch_size)] + preds = [] + for user, item in batches: + def proc(x): + x = np.array(x) + x = torch.from_numpy(x) + if use_cuda: + x = x.cuda(async=True) + return torch.autograd.Variable(x) + outp = model(proc(user), proc(item), sigmoid=True) + outp = outp.data.cpu().numpy() + preds += list(outp.flatten()) + return preds + + +def _calculate_hit(ranked, test_item): + return int(test_item in ranked) + + +def _calculate_ndcg(ranked, test_item): + for i, item in enumerate(ranked): + if item == test_item: + return math.log(2) / math.log(i + 2) + return 0. + + +def eval_one(rating, items, model, K, use_cuda=True): + user = rating[0] + test_item = rating[1] + items.append(test_item) + # items.insert(0, test_item) + users = [user] * len(items) + predictions = predict(model, users, items, use_cuda=use_cuda) + + map_item_score = {item: pred for item, pred in zip(items, predictions)} + ranked = heapq.nlargest(K, map_item_score, key=map_item_score.get) + + hit = _calculate_hit(ranked, test_item) + ndcg = _calculate_ndcg(ranked, test_item) + # return user, hit, ndcg + return hit, ndcg + + +def val_epoch(model, ratings, negs, K, use_cuda=True, output=None, epoch=None, + processes=1, num_users=-1): + if epoch is None: + msglogger.info("Initial evaluation") + else: + msglogger.info("Epoch {} evaluation".format(epoch)) + start = datetime.now() + model.eval() + + if num_users > 0: + ratings = ratings[:num_users] + negs = negs[:num_users] + + if processes > 1: + context = mp.get_context('spawn') + _eval_one = partial(eval_one, model=model, K=K, use_cuda=use_cuda) + with context.Pool(processes=processes) as workers: + hits_and_ndcg = workers.starmap(_eval_one, zip(ratings, negs)) + + # pbar = tqdm.tqdm(total=len(ratings)) + # hits_and_ndcg = [None] * len(ratings) + # + # def update_pbar(idx_hit_ncdg): + # idx, hit, ncdg = idx_hit_ncdg + # hits_and_ndcg[idx] = (hit, ncdg) + # pbar.update() + # + # context = mp.get_context('spawn') + # pool = context.Pool(processes=processes) + # for idx, (rating, items) in enumerate(zip(ratings, negs)): + # pool.apply_async(eval_one, args=(rating, items, model, K, use_cuda), callback=update_pbar) + # pool.close() + # pool.join() + # pbar.close() + + hits, ndcgs = zip(*hits_and_ndcg) + else: + hits, ndcgs = [], [] + with tqdm.tqdm(zip(ratings, negs), total=len(ratings)) as t: + for rating, items in t: + hit, ndcg = eval_one(rating, items, model, K, use_cuda=use_cuda) + hits.append(hit) + ndcgs.append(ndcg) + steps_completed = len(hits) + 1 + if steps_completed % 100 == 0: + t.set_description('HR@10 = {0:.4f}, NDCG = {1:.4f}'.format(np.mean(hits), np.mean(ndcgs))) + + hits = np.array(hits, dtype=np.float32) + ndcgs = np.array(ndcgs, dtype=np.float32) + + end = datetime.now() + if output is not None: + result = OrderedDict() + result['timestamp'] = datetime.now() + result['duration'] = end - start + result['epoch'] = epoch + result['K'] = K + result['hit_rate'] = np.mean(hits) + result['NDCG'] = np.mean(ndcgs) + utils.save_result(result, output) + + return hits, ndcgs + + +def main(): + global msglogger + + script_dir = os.path.dirname(__file__) + module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) + + args = parse_args() + + # Distiller loggers + msglogger = apputils.config_pylogger('logging.conf', args.name, output_dir=args.output_dir) + tflogger = TensorBoardLogger(msglogger.logdir) + # tflogger.log_gradients = True + # pylogger = PythonLogger(msglogger) + + if args.seed is not None: + msglogger.info("Using seed = {}".format(args.seed)) + torch.manual_seed(args.seed) + np.random.seed(seed=args.seed) + + args.qe_mode = str(args.qe_mode).split('.')[1] + args.qe_clip_acts = str(args.qe_clip_acts).split('.')[1] + + apputils.log_execution_env_state(sys.argv, gitroot=module_path) + + if args.gpus is not None: + try: + args.gpus = [int(s) for s in args.gpus.split(',')] + except ValueError: + msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only') + exit(1) + # if len(args.gpus) > 1: + # msglogger.error('ERROR: Only single GPU supported for NCF') + # exit(1) + available_gpus = torch.cuda.device_count() + for dev_id in args.gpus: + if dev_id >= available_gpus: + msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available' + .format(dev_id, available_gpus)) + exit(1) + # Set default device in case the first one on the list != 0 + torch.cuda.set_device(args.gpus[0]) + + # Save configuration to file + config = {k: v for k, v in args.__dict__.items()} + config['timestamp'] = "{:.0f}".format(datetime.utcnow().timestamp()) + config['local_timestamp'] = str(datetime.now()) + # run_dir = "./run/neumf/{}".format(config['timestamp']) + run_dir = msglogger.logdir + msglogger.info("Saving config and results to {}".format(run_dir)) + if not os.path.exists(run_dir) and run_dir != '': + os.makedirs(run_dir) + utils.save_config(config, run_dir) + + # Check that GPUs are actually available + use_cuda = not args.no_cuda and torch.cuda.is_available() + + t1 = time.time() + # Load Data + training = not (args.eval or args.qe_calibration or args.activation_histograms) + msglogger.info('Loading data') + if training: + train_dataset = CFTrainDataset( + os.path.join(args.data, TRAIN_RATINGS_FILENAME), args.negative_samples) + train_dataloader = torch.utils.data.DataLoader( + dataset=train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True) + nb_users, nb_items = train_dataset.nb_users, train_dataset.nb_items + else: + train_dataset = None + train_dataloader = None + nb_users, nb_items = (276986, 53488) + + test_ratings = load_test_ratings(os.path.join(args.data, TEST_RATINGS_FILENAME)) # noqa: E501 + test_negs = load_test_negs(os.path.join(args.data, TEST_NEG_FILENAME)) + + msglogger.info('Load data done [%.1f s]. #user=%d, #item=%d, #train=%s, #test=%d' + % (time.time()-t1, nb_users, nb_items, str(train_dataset.mat.nnz) if training else 'N/A', + len(test_ratings))) + + # Create model + model = NeuMF(nb_users, nb_items, + mf_dim=args.factors, mf_reg=0., + mlp_layer_sizes=args.layers, + mlp_layer_regs=[0. for i in args.layers], + split_final=args.split_final) + if use_cuda: + # Move model and loss to GPU + # if len(args.gpus) > 1: + # model = torch.nn.DataParallel(model, device_ids=args.gpus) + model = model.cuda() + msglogger.info(model) + msglogger.info("{} parameters".format(utils.count_parameters(model))) + + # Save model text description + with open(os.path.join(run_dir, 'model.txt'), 'w') as file: + file.write(str(model)) + + compression_scheduler = None + start_epoch = 0 + optimizer = None + if args.load: + if training: + model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(model, args.load) + if args.reset_optimizer: + start_epoch = 0 + optimizer = None + else: + model = apputils.load_lean_checkpoint(model, args.load) + + # Add loss to graph + criterion = nn.BCEWithLogitsLoss() + + if use_cuda: + criterion = criterion.cuda() + + if training and optimizer is None: + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + msglogger.info('Optimizer Type: %s', type(optimizer)) + msglogger.info('Optimizer Args: %s', optimizer.defaults) + + if args.compress: + compression_scheduler = distiller.file_config(model, optimizer, args.compress) + model.cuda() + + # Create files for tracking training + valid_results_file = os.path.join(run_dir, 'valid_results.csv') + + if args.qe_calibration or args.activation_histograms: + calib = {'portion': args.qe_calibration, + 'desc_str': 'quantization calibration stats', + 'collect_func': partial(distiller.data_loggers.collect_quant_stats, inplace_runtime_check=True, + disable_inplace_attrs=True)} + hists = {'portion': args.activation_histograms, + 'desc_str': 'activation histograms', + 'collect_func': partial(distiller.data_loggers.collect_histograms, activation_stats=None, nbins=2048, + save_hist_imgs=True)} + d = calib if args.qe_calibration else hists + + distiller.utils.assign_layer_fq_names(model) + num_users = int(np.floor(len(test_ratings) * d['portion'])) + msglogger.info( + "Generating {} based on {:.1%} of the test-set ({} users)".format(d['desc_str'], d['portion'], num_users)) + + test_fn = partial(val_epoch, ratings=test_ratings, negs=test_negs, K=args.topk, use_cuda=use_cuda, + processes=args.processes, num_users=num_users) + d['collect_func'](model=model, test_fn=test_fn, save_dir=run_dir, classes=None) + + return 0 + + if args.eval: + if args.quantize_eval and args.qe_calibration is None: + model.cpu() + quantizer = quantization.PostTrainLinearQuantizer.from_args(model, args) + quantizer.prepare_model() + model.cuda() + + distiller.utils.assign_layer_fq_names(model) + + if args.eval_fp16: + model = model.half() + + # Calculate initial Hit Ratio and NDCG + begin = time.time() + hits, ndcgs = val_epoch(model, test_ratings, test_negs, args.topk, + use_cuda=use_cuda, processes=args.processes) + val_time = time.time() - begin + hit_rate = np.mean(hits) + msglogger.info('Initial HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}, val_time = {val_time:.2f}' + .format(K=args.topk, hit_rate=hit_rate, ndcg=np.mean(ndcgs), val_time=val_time)) + hit_rate = 0 + + if args.quantize_eval: + checkpoint_name = 'quantized' + apputils.save_checkpoint(0, 'NCF', model, optimizer=None, extras={'quantized_hr@10': hit_rate}, + name='_'.join([args.name, 'quantized']) if args.name else checkpoint_name, + dir=msglogger.logdir) + return 0 + + total_samples = len(train_dataloader.sampler) + steps_per_epoch = math.ceil(total_samples / args.batch_size) + best_hit_rate = 0 + best_epoch = 0 + for epoch in range(start_epoch, args.epochs): + msglogger.info('') + model.train() + losses = utils.AverageMeter() + + begin = time.time() + + if compression_scheduler: + compression_scheduler.on_epoch_begin(epoch, optimizer) + + loader = tqdm.tqdm(train_dataloader) + for batch_index, (user, item, label) in enumerate(loader): + user = torch.autograd.Variable(user, requires_grad=False) + item = torch.autograd.Variable(item, requires_grad=False) + label = torch.autograd.Variable(label, requires_grad=False) + if use_cuda: + user = user.cuda(async=True) + item = item.cuda(async=True) + label = label.cuda(async=True) + + if compression_scheduler: + compression_scheduler.on_minibatch_begin(epoch, batch_index, steps_per_epoch, optimizer) + + outputs = model(user, item) + loss = criterion(outputs, label) + + if compression_scheduler: + compression_scheduler.before_backward_pass(epoch, batch_index, steps_per_epoch, loss, optimizer, + return_loss_components=False) + + losses.update(loss.data.item(), user.size(0)) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if compression_scheduler: + compression_scheduler.on_minibatch_end(epoch, batch_index, steps_per_epoch, optimizer) + + # Save stats to file + description = ('Epoch {} Loss {loss.val:.4f} ({loss.avg:.4f})' + .format(epoch, loss=losses)) + loader.set_description(description) + + steps_completed = batch_index + 1 + if steps_completed % args.log_freq == 0: + stats_dict = OrderedDict() + stats_dict['Loss'] = losses.avg + stats = ('Performance/Training/', stats_dict) + params = model.named_parameters() if args.log_params_histograms else None + distiller.log_training_progress(stats, params, epoch, steps_completed, steps_per_epoch, args.log_freq, + [tflogger]) + + tflogger.log_model_buffers(model, ['tracked_min', 'tracked_max'], 'Quant/Train/Acts/TrackedMinMax', + epoch, steps_completed, steps_per_epoch, args.log_freq) + + train_time = time.time() - begin + begin = time.time() + hits, ndcgs = val_epoch(model, test_ratings, test_negs, args.topk, + use_cuda=use_cuda, output=valid_results_file, + epoch=epoch, processes=args.processes) + val_time = time.time() - begin + + if compression_scheduler: + compression_scheduler.on_epoch_end(epoch, optimizer) + + hit_rate = np.mean(hits) + mean_ndcgs = np.mean(ndcgs) + + stats_dict = OrderedDict() + stats_dict['HR@{0}'.format(args.topk)] = hit_rate + stats_dict['NDCG@{0}'.format(args.topk)] = mean_ndcgs + stats = ('Performance/Validation/', stats_dict) + distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1, + loggers=[tflogger]) + + msglogger.info('Epoch {epoch}: HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}, AvgTrainLoss = {loss.avg:.4f}, ' + 'train_time = {train_time:.2f}, val_time = {val_time:.2f}'.format( + epoch=epoch, K=args.topk, hit_rate=hit_rate, ndcg=mean_ndcgs, + loss=losses, train_time=train_time, val_time=val_time)) + + is_best = False + if hit_rate > best_hit_rate: + best_hit_rate = hit_rate + is_best = True + best_epoch = epoch + extras = {'current_hr@10': hit_rate, + 'best_hr@10': best_hit_rate, + 'best_epoch': best_epoch} + apputils.save_checkpoint(epoch, 'NCF', model, optimizer, compression_scheduler, extras, is_best, dir=run_dir) + + if args.threshold is not None: + if np.mean(hits) >= args.threshold: + msglogger.info("Hit threshold of {}".format(args.threshold)) + break + + +if __name__ == '__main__': + try: + main() + except KeyboardInterrupt: + print("\n-- KeyboardInterrupt --") + finally: + if msglogger is not None: + msglogger.info('') + msglogger.info('Log file for this run: ' + os.path.realpath(msglogger.log_filename)) diff --git a/examples/ncf/neumf.py b/examples/ncf/neumf.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf71b312b1cf799f2ba0d26ee19611ba3d3b0d3 --- /dev/null +++ b/examples/ncf/neumf.py @@ -0,0 +1,151 @@ +import numpy as np +import torch +import torch.nn as nn + +import distiller.modules + +import logging +import os +msglogger = logging.getLogger() + + +class NeuMF(nn.Module): + def __init__(self, nb_users, nb_items, + mf_dim, mf_reg, + mlp_layer_sizes, mlp_layer_regs, split_final=False): + if len(mlp_layer_sizes) != len(mlp_layer_regs): + raise RuntimeError('u dummy, layer_sizes != layer_regs!') + if mlp_layer_sizes[0] % 2 != 0: + raise RuntimeError('u dummy, mlp_layer_sizes[0] % 2 != 0') + super(NeuMF, self).__init__() + + self.mf_dim = mf_dim + self.mlp_layer_sizes = mlp_layer_sizes + + nb_mlp_layers = len(mlp_layer_sizes) + + # TODO: regularization? + self.mf_user_embed = nn.Embedding(nb_users, mf_dim) + self.mf_item_embed = nn.Embedding(nb_items, mf_dim) + self.mlp_user_embed = nn.Embedding(nb_users, mlp_layer_sizes[0] // 2) + self.mlp_item_embed = nn.Embedding(nb_items, mlp_layer_sizes[0] // 2) + + self.mf_mult = distiller.modules.EltwiseMult() + self.mlp_concat = distiller.modules.Concat(dim=1) + + self.mlp = nn.ModuleList() + self.mlp_relu = nn.ModuleList() + for i in range(1, nb_mlp_layers): + self.mlp.extend([nn.Linear(mlp_layer_sizes[i - 1], mlp_layer_sizes[i])]) # noqa: E501 + self.mlp_relu.extend([nn.ReLU()]) + + self.split_final = split_final + if not split_final: + self.final_concat = distiller.modules.Concat(dim=1) + self.final = nn.Linear(mlp_layer_sizes[-1] + mf_dim, 1) + else: + self.final_mlp = nn.Linear(mlp_layer_sizes[-1], 1) + self.final_mf = nn.Linear(mf_dim, 1) + self.final_add = distiller.modules.EltwiseAdd() + + self.sigmoid = nn.Sigmoid() + + self.mf_user_embed.weight.data.normal_(0., 0.01) + self.mf_item_embed.weight.data.normal_(0., 0.01) + self.mlp_user_embed.weight.data.normal_(0., 0.01) + self.mlp_item_embed.weight.data.normal_(0., 0.01) + + def golorot_uniform(layer): + fan_in, fan_out = layer.in_features, layer.out_features + limit = np.sqrt(6. / (fan_in + fan_out)) + layer.weight.data.uniform_(-limit, limit) + + def lecunn_uniform(layer): + fan_in, fan_out = layer.in_features, layer.out_features # noqa: F841, E501 + limit = np.sqrt(3. / fan_in) + layer.weight.data.uniform_(-limit, limit) + + for layer in self.mlp: + if type(layer) != nn.Linear: + continue + golorot_uniform(layer) + if not split_final: + lecunn_uniform(self.final) + else: + lecunn_uniform(self.final_mlp) + lecunn_uniform(self.final_mf) + + # self.post_embed_device = torch.device('cpu') + + def load_state_dict(self, state_dict, strict=True): + if 'final.weight' in state_dict and self.split_final: + # Loading no-split checkpoint into split model + + # MF weights come first, then MLP + final_weight = state_dict.pop('final.weight') + state_dict['final_mf.weight'] = final_weight[0][:self.mf_dim].unsqueeze(0) + state_dict['final_mlp.weight'] = final_weight[0][self.mf_dim:].unsqueeze(0) + + # Split bias 50-50 + final_bias = state_dict.pop('final.bias') + state_dict['final_mf.bias'] = final_bias * 0.5 + state_dict['final_mlp.bias'] = final_bias * 0.5 + elif 'final_mf.weight' in state_dict and not self.split_final: + # Loading split checkpoint into no-split model + state_dict['final.weight'] = torch.cat((state_dict.pop('final_mf.weight')[0], + state_dict.pop('final_mlp.weight')[0])).unsqueeze(0) + state_dict['final.bias'] = state_dict.pop('final_mf.bias') + state_dict.pop('final_mlp.bias') + + super(NeuMF, self).load_state_dict(state_dict, strict) + + def forward(self, user, item, sigmoid=False): + xmfu = self.mf_user_embed(user) # .to(self.post_embed_device) + xmfi = self.mf_item_embed(item) # .to(self.post_embed_device) + xmf = self.mf_mult(xmfu, xmfi) + # @DEBUG + # np.save(os.path.join(msglogger.logdir, 'mf_mult'), xmf.cpu().detach().numpy()) + + xmlpu = self.mlp_user_embed(user) # .to(self.post_embed_device) + xmlpi = self.mlp_item_embed(item) # .to(self.post_embed_device) + xmlp = self.mlp_concat(xmlpu, xmlpi) + # @DEBUG + # np.save(os.path.join(msglogger.logdir, 'mlp_concat'), xmlp.cpu().detach().numpy()) + for i, (layer, act) in enumerate(zip(self.mlp, self.mlp_relu)): + xmlp = layer(xmlp) + # @DEBUG + # np.save(os.path.join(msglogger.logdir, 'mlp.{}'.format(i)), xmlp.detach().cpu().numpy()) + xmlp = act(xmlp) + # @DEBUG + # np.save(os.path.join(msglogger.logdir, 'mlp_relu.{}'.format(i)), xmlp.detach().cpu().numpy()) + + if not self.split_final: + x = self.final_concat(xmf, xmlp) + x = self.final(x) + else: + xmf = self.final_mf(xmf) + # @DEBUG + # np.save(os.path.join(msglogger.logdir, 'final_mf'), xmf.detach().cpu().numpy()) + xmlp = self.final_mlp(xmlp) + # @DEBUG + # np.save(os.path.join(msglogger.logdir, 'final_mlp'), xmlp.detach().cpu().numpy()) + x = self.final_add(xmf, xmlp) + # @DEBUG + # np.save(os.path.join(msglogger.logdir, 'final_add'), x.detach().cpu().numpy()) + if sigmoid: + x = self.sigmoid(x) + return x + + # def to_cuda(self, device=None, embeds_on_gpu=True): + # self.post_embed_device = device if device is not None else torch.device('cuda') + # + # if embeds_on_gpu: + # return self.cuda(device=device) + # + # for m in self.modules(): + # if isinstance(m, nn.Embedding): + # m.cpu() + # else: + # m.cuda(device=device) + # + # return self + diff --git a/examples/ncf/requirements.txt b/examples/ncf/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cfa65ea03548eeef5e7b2d7d60ea620106ff8e49 --- /dev/null +++ b/examples/ncf/requirements.txt @@ -0,0 +1,3 @@ +tqdm==4.20.0 +scipy +pandas diff --git a/examples/ncf/run_and_time.sh b/examples/ncf/run_and_time.sh new file mode 100755 index 0000000000000000000000000000000000000000..791248e4494ec03a7d351e3e52a212fbded4538d --- /dev/null +++ b/examples/ncf/run_and_time.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# runs benchmark and reports time to convergence +# to use the script: +# run_and_time.sh <random seed 1-5> + +THRESHOLD=0.635 +BASEDIR=$(dirname -- "$0") + +# start timing +start=$(date +%s) +start_fmt=$(date +%Y-%m-%d\ %r) +echo "STARTING TIMING RUN AT $start_fmt" + +# Get command line seed +seed=${1:-1} + +echo "unzip ml-20m.zip" +if unzip ml-20m.zip +then + echo "Start processing ml-20m/ratings.csv" + t0=$(date +%s) + python $BASEDIR/convert.py ml-20m/ratings.csv ml-20m --negatives 999 + t1=$(date +%s) + delta=$(( $t1 - $t0 )) + echo "Finish processing ml-20m/ratings.csv in $delta seconds" + + echo "Start training" + t0=$(date +%s) + python $BASEDIR/ncf.py ml-20m -l 0.0005 -b 2048 --layers 256 256 128 64 -f 64 \ + --seed $seed --threshold $THRESHOLD --processes 10 + t1=$(date +%s) + delta=$(( $t1 - $t0 )) + echo "Finish training in $delta seconds" + + # end timing + end=$(date +%s) + end_fmt=$(date +%Y-%m-%d\ %r) + echo "ENDING TIMING RUN AT $end_fmt" + + + # report result + result=$(( $end - $start )) + result_name="recommendation" + + + echo "RESULT,$result_name,$seed,$result,$USER,$start_fmt" +else + echo "Problem unzipping ml-20.zip" + echo "Please run 'download_data.sh && verify_datset.sh' first" +fi + + + + + diff --git a/examples/ncf/utils.py b/examples/ncf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4395830c050b54fbfc4212dff4bbf5d3c755ed00 --- /dev/null +++ b/examples/ncf/utils.py @@ -0,0 +1,41 @@ +import os +import json +from functools import reduce + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def count_parameters(model): + c = map(lambda p: reduce(lambda x, y: x * y, p.size()), model.parameters()) + return sum(c) + + +def save_config(config, run_dir): + path = os.path.join(run_dir, "config_{}.json".format(config['timestamp'])) + with open(path, 'w') as config_file: + json.dump(config, config_file) + config_file.write('\n') + + +def save_result(result, path): + write_heading = not os.path.exists(path) + with open(path, mode='a') as out: + if write_heading: + out.write(",".join([str(k) for k, v in result.items()]) + '\n') + out.write(",".join([str(v) for k, v in result.items()]) + '\n') diff --git a/examples/ncf/verify_dataset.sh b/examples/ncf/verify_dataset.sh new file mode 100755 index 0000000000000000000000000000000000000000..208d7602a8fad8bf3151edf90e8ecd2641195e14 --- /dev/null +++ b/examples/ncf/verify_dataset.sh @@ -0,0 +1,44 @@ +function get_checker { + if [[ "$OSTYPE" == "darwin"* ]]; then + checkmd5=md5 + else + checkmd5=md5sum + fi + + echo $checkmd5 +} + + +function verify_1m { + # From: curl -O http://files.grouplens.org/datasets/movielens/ml-1m.zip.md5 + hash=<(echo "MD5 (ml-1m.zip) = c4d9eecfca2ab87c1945afe126590906") + local checkmd5=$(get_checker) + if diff <($checkmd5 ml-1m.zip) $hash &> /dev/null + then + echo "PASSED" + else + echo "FAILED" + fi +} + +function verify_20m { + # From: curl -O http://files.grouplens.org/datasets/movielens/ml-20m.zip.md5 + hash=<(echo "MD5 (ml-20m.zip) = cd245b17a1ae2cc31bb14903e1204af3") + local checkmd5=$(get_checker) + + if diff <($checkmd5 ml-20m.zip) $hash &> /dev/null + then + echo "PASSED" + else + echo "FAILED" + fi + +} + + +if [[ $1 == "ml-1m" ]] +then + verify_1m +else + verify_20m +fi diff --git a/examples/quantization/quant_aware_train/ncf_quant_aware_train_linear.yaml b/examples/quantization/quant_aware_train/ncf_quant_aware_train_linear.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1c2d7c4d743fb89802087d8a778db5d1528dc54 --- /dev/null +++ b/examples/quantization/quant_aware_train/ncf_quant_aware_train_linear.yaml @@ -0,0 +1,29 @@ +quantizers: + linear_quantizer: + class: NCFQuantAwareTrainQuantizer + bits_activations: 8 + bits_weights: 8 + bits_bias: 32 + mode: 'SYMMETRIC' # Can try "SYMMETRIC" as well + ema_decay: 0.999 # Decay value for exponential moving average tracking of activation ranges + per_channel_wts: True + overrides: + # We want to quantize the last FC layer prior to the sigmoid. So - We set up the quantizer to add fake-quantization + # layers after FC layers. But - here we override so that this doesn't actually happen in any of the early FC layers + mlp\.*: + bits_activations: null + bits_weights: 8 + bits_bias: 32 + final_concat: + bits_activations: null + bits_weights: null + bits_bias: null + +policies: + - quantizer: + instance_name: linear_quantizer + # For now putting a large range here, which should cover both training from scratch or resuming from some + # pre-trained checkpoint at some unknown epoch + starting_epoch: 0 + ending_epoch: 300 + frequency: 1 \ No newline at end of file