diff --git a/examples/ncf/README.md b/examples/ncf/README.md index fc477a629975fbc7ae012e5ab8dda95ad785f5e8..827893e1b296ebd03fc4a4c7136201d2df28a46d 100644 --- a/examples/ncf/README.md +++ b/examples/ncf/README.md @@ -1,6 +1,8 @@ # NCF - Neural Collaborative Filtering -The NCF implementation provided here is based on the implementation found in the MLPerf Training GitHub repository, specifically on the last revision of the code before the switch to the extended dataset. See [here](https://github.com/mlperf/training/tree/fe17e837ed12974d15c86d5173fe8f2c188434d5/recommendation/pytorch). +The NCF implementation provided here is based on the implementation found in the MLPerf Training GitHub repository. +This sample is not based on the latest implementation in MLPerf, it is based on an earlier revision which uses the ml-20m dataset. The latest code uses a much larger dataset. We plan to move to the latest version in the near future. +You can fine the revision this sample is based on [here](https://github.com/mlperf/training/tree/fe17e837ed12974d15c86d5173fe8f2c188434d5/recommendation/pytorch). We've made several modifications to the code: * Removed all MLPerf specific code including logging @@ -8,8 +10,8 @@ We've made several modifications to the code: * Added calls to Distiller compression APIs * Added progress indication in training and evaluation flows * In `neumf.py`: - * Added option to split final FC layer - * Replaced all functional calls with modules so they can be detected by Distiller + * Added option to split final the FC layer (the `split_final` parameter). See [below](#side-note-splitting-the-final-fc-layer). + * Replaced all functional calls with modules so they can be detected by Distiller, as per this [guide](https://nervanasystems.github.io/distiller/prepare_model_quant.html) in the Distiller docs. * In `dataset.py`: * Speed up data loading - On first data will is loaded from CSVs and then pickled. On subsequent runs the pickle is loaded. This is much faster than the original implementation, but still very slow. * Added progress indication during data load process @@ -24,34 +26,121 @@ The model trains on binary information about whether or not a user interacted wi ### Steps to configure machine -1. Install `unzip` and `curl` +* Install `unzip` and `curl` + + ```bash + sudo apt-get install unzip curl + ``` + +* Make sure the latest Distiller requirements are installed + + ```bash + # Relative to this sample directory + cd <distiller-repo-root> + pip install -e . + ``` + +* Download and verify data + + ```bash + cd <distiller-repo-root>/examples/ncf + # Creates ml-20.zip + source ../download_dataset.sh + # Confirms the MD5 checksum of ml-20.zip + source ../verify_dataset.sh + ``` + +## Running the Sample + +### Train a Base FP32 Model + +We train a model with the following parameters: + +* MLP Side + * Embedding size per user / item: 128 + * FC layer sizes: 256x256 --> 256x128 --> 128x64 +* MF (matrix factorization) Side + * Embedding size per user / item: 64 +* Therefore, the final FC layer size is: 128x1 + +Adam optimizer is used, with an initial learning rate of 0.0005. Batch size is 2048. Convergence is obtained after 7 epochs. ```bash -sudo apt-get install unzip curl +python ncf.py ml-20m -l 0.0005 -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --processes 10 -o run/neumf/base_fp32 +... +Epoch 0 Loss 0.1179 (0.1469): 100%|█████████████████████████████| 48491/48491 [07:04<00:00, 114.23it/s] +Epoch 0 evaluation +Epoch 0: HR@10 = 0.5738, NDCG@10 = 0.3367, AvgTrainLoss = 0.1469, train_time = 424.52, val_time = 47.04 +... +Epoch 6 Loss 0.0914 (0.0943): 100%|█████████████████████████████| 48491/48491 [06:47<00:00, 118.90it/s] +Epoch 6 evaluation +Epoch 6: HR@10 = 0.6355, NDCG@10 = 0.3820, AvgTrainLoss = 0.0943, train_time = 407.84, val_time = 62.99 +``` + +The hit-rate of the base model is 63.55. + +### Side-Note: Splitting the Final FC Layer + +As mentioned above, we added an option to split the final FC layer of the model (the `split_final` parameter in `NeuMF.__init__`). + +The reasoning behind this is that the input to the final FC layer in NCF is a concatenation of the outputs of the MLP and MF "branches". These outputs have very different dynamic ranges. +In the model we just trained, the MLP branch output range is [0 .. 203] while the MF branch output range is [-6.3 .. 7.4]. When doing quantized concatenation, we have to accommodate the larger range, which leads to a large quantization error for the data that came from the MF branch. When quantizing to 8-bits, the MF branch will cover only 10 bins out of the 256 bins, which means just over 3-bits. +The mitigation we use is to split the final FC layer as follows: + +``` + Before Split: After Split: + ------------- ------------ + MF_OUT MLP_OUT MF_OUT MLP_OUT + \ / | | + \ / ---> MF_FC MLP_FC + CONCAT \ / + | \ / + FINAL_FC \ / + ADD ``` +After splitting, the two inputs to the add operation have ranges [-283 .. 40] from the MLP side and [-54 .. 47] from the MF side. While the problem isn't completely solved, it's much better than before. Now the MF covers 126 bins, which is almost 7-bits. -2. Install required python packages +Note that in FP32 the 2 modes are functionally identical. The split final option is for evaluation only, and we take care to convert the model trained without splitting into a split model when loading the checkpoint. + +### Collect Quantization Stats for Post-Training Quantization + +We generated stats for both the non-split and split case. These are the `quantization_stats_no_split.yaml` and `quantization_stats_split.yaml` files in the example folder. + +For reference, the command lines used to generate these are: ```bash -pip install -r requirements.txt +python ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --qe-calibration 0.1 +python ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --qe-calibration 0.1 --split-final ``` +Note that `--qe-calibration 0.1` means that we use 10% of the test dataset for the stats collection. + +### Post-Training Quantization Experiments + +We'll use the following settings for quantization: -3. Download and verify data +* 8-bits for weights and activations: `--qeba 8 --qebw 8` +* Asymmetric: `--qem asym_u` +* Per-channel: `--qepc` + +Let's see the difference splitting the final FC layer makes in terms of overall accuracy: ```bash -# Creates ml-20.zip -source ../download_dataset.sh -# Confirms the MD5 checksum of ml-20.zip -source ../verify_dataset.sh +ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --evaluate --quantize-eval --qeba 8 --qebw 8 --qem asym_u --qepc --qe-stats-file quantization_stats_no_split.yaml +... +Initial HR@10 = 0.4954, NDCG@10 = 0.2802, val_time = 521.11 ``` -## Running the Sample +```bash +ncf.py ml-20m -b 2048 --layers 256 256 128 64 -f 64 --seed 1 --load run/neumf/base_fp32/best.pth.tar --evaluate --quantize-eval --qeba 8 --qebw 8 --qem asym_u --qepc --split-final --qe-stats-file quantization_stats_split.yaml +... +HR@10 = 0.6278, NDCG@10 = 0.3760, val_time = 601.87 +``` -### TODO: Add some Distiller specific example command line +We can see that without splitting, we get ~14% degradation in hit-rate. With splitting we gain almost all of the accuracy back, with about 0.8% degradation. -## Dataset/Environment +## Dataset / Environment -### Publication/Attribution +### Publication / Attribution Harper, F. M. & Konstan, J. A. (2015), 'The MovieLens Datasets: History and Context', ACM Trans. Interact. Intell. Syst. 5(4), 19:1--19:19. @@ -80,17 +169,3 @@ Data is traversed randomly with 4 negative examples selected on average for ever Xiangnan He, Lizi Liao, Hanwang Zhang, Liqiang Nie, Xia Hu and Tat-Seng Chua (2017). [Neural Collaborative Filtering](http://dl.acm.org/citation.cfm?id=3052569). In Proceedings of WWW '17, Perth, Australia, April 03-07, 2017. The author's original code is available at [hexiangnan/neural_collaborative_filtering](https://github.com/hexiangnan/neural_collaborative_filtering). - -## Quality - -### Quality metric - -Hit rate at 10 (HR@10) with 999 negative items. - -### Evaluation frequency - -After every epoch through the training data. - -### Evaluation thoroughness - -Every users last item rated, i.e. all held out positive examples. diff --git a/examples/ncf/ncf.py b/examples/ncf/ncf.py index b3148ea97079a4c6e9f9c834e6f4743bcd012cce..98376ca825ae800cbfbc8aeae08df176b765d174 100644 --- a/examples/ncf/ncf.py +++ b/examples/ncf/ncf.py @@ -155,23 +155,6 @@ def val_epoch(model, ratings, negs, K, use_cuda=True, output=None, epoch=None, _eval_one = partial(eval_one, model=model, K=K, use_cuda=use_cuda) with context.Pool(processes=processes) as workers: hits_and_ndcg = workers.starmap(_eval_one, zip(ratings, negs)) - - # pbar = tqdm.tqdm(total=len(ratings)) - # hits_and_ndcg = [None] * len(ratings) - # - # def update_pbar(idx_hit_ncdg): - # idx, hit, ncdg = idx_hit_ncdg - # hits_and_ndcg[idx] = (hit, ncdg) - # pbar.update() - # - # context = mp.get_context('spawn') - # pool = context.Pool(processes=processes) - # for idx, (rating, items) in enumerate(zip(ratings, negs)): - # pool.apply_async(eval_one, args=(rating, items, model, K, use_cuda), callback=update_pbar) - # pool.close() - # pool.join() - # pbar.close() - hits, ndcgs = zip(*hits_and_ndcg) else: hits, ndcgs = [], [] @@ -231,9 +214,9 @@ def main(): except ValueError: msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only') exit(1) - # if len(args.gpus) > 1: - # msglogger.error('ERROR: Only single GPU supported for NCF') - # exit(1) + if len(args.gpus) > 1: + msglogger.error('ERROR: Only single GPU supported for NCF') + exit(1) available_gpus = torch.cuda.device_count() for dev_id in args.gpus: if dev_id >= available_gpus: @@ -247,7 +230,6 @@ def main(): config = {k: v for k, v in args.__dict__.items()} config['timestamp'] = "{:.0f}".format(datetime.utcnow().timestamp()) config['local_timestamp'] = str(datetime.now()) - # run_dir = "./run/neumf/{}".format(config['timestamp']) run_dir = msglogger.logdir msglogger.info("Saving config and results to {}".format(run_dir)) if not os.path.exists(run_dir) and run_dir != '': @@ -287,9 +269,6 @@ def main(): mlp_layer_regs=[0. for i in args.layers], split_final=args.split_final) if use_cuda: - # Move model and loss to GPU - # if len(args.gpus) > 1: - # model = torch.nn.DataParallel(model, device_ids=args.gpus) model = model.cuda() msglogger.info(model) msglogger.info("{} parameters".format(utils.count_parameters(model))) diff --git a/examples/ncf/quantization_stats_no_split.yaml b/examples/ncf/quantization_stats_no_split.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f32232c563dce72256d95b8a9e36363d5fe1abf4 --- /dev/null +++ b/examples/ncf/quantization_stats_no_split.yaml @@ -0,0 +1,294 @@ +mf_user_embed: + inputs: + 0: + min: 0 + max: 13848 + avg_min: 6924.0 + avg_max: 6924.0 + mean: 6924.0 + std: 3997.8620729189115 + shape: (1000) + output: + min: -2.3631081581115723 + max: 2.634204864501953 + avg_min: -1.071664170499179 + avg_max: 1.10968593620068 + mean: 0.004670127670715561 + std: 0.4853287812380178 + shape: (1000, 64) +mf_item_embed: + inputs: + 0: + min: 0 + max: 26743 + avg_min: 32.519748718318944 + avg_max: 26716.69644017626 + mean: 13429.908568849716 + std: 7695.305228284578 + shape: (1000) + output: + min: -3.886291742324829 + max: 3.3996453285217285 + avg_min: -1.8537866596848787 + avg_max: 1.733846694667886 + mean: -0.1020138903885282 + std: 0.9169524390367664 + shape: (1000, 64) +mlp_user_embed: + inputs: + 0: + min: 0 + max: 13848 + avg_min: 6924.0 + avg_max: 6924.0 + mean: 6924.0 + std: 3997.8620729189115 + shape: (1000) + output: + min: -2.3077099323272705 + max: 2.019761323928833 + avg_min: -0.8393908124364596 + avg_max: 0.8563097013907461 + mean: -0.0058890159863465566 + std: 0.33145011084640036 + shape: (1000, 128) +mlp_item_embed: + inputs: + 0: + min: 0 + max: 26743 + avg_min: 32.519748718318944 + avg_max: 26716.69644017626 + mean: 13429.908568849716 + std: 7695.305228284578 + shape: (1000) + output: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.6649888157199006 + avg_max: 1.5336890253437756 + mean: 0.03420270613169851 + std: 0.6040079522209375 + shape: (1000, 128) +mf_mult: + inputs: + 0: + min: -2.3631081581115723 + max: 2.634204864501953 + avg_min: -1.071664170499179 + avg_max: 1.10968593620068 + mean: 0.004670127670715561 + std: 0.4853287812380178 + shape: (1000, 64) + 1: + min: -3.886291742324829 + max: 3.3996453285217285 + avg_min: -1.8537866596848787 + avg_max: 1.733846694667886 + mean: -0.1020138903885282 + std: 0.9169524390367664 + shape: (1000, 64) + output: + min: -6.388758659362793 + max: 7.461198329925537 + avg_min: -1.4091147140708975 + avg_max: 1.190038402591932 + mean: -0.03373257358976464 + std: 0.47984411373143276 + shape: (1000, 64) +mlp_concat: + inputs: + 0: + min: -2.3077099323272705 + max: 2.019761323928833 + avg_min: -0.8393908124364596 + avg_max: 0.8563097013907461 + mean: -0.0058890159863465566 + std: 0.33145011084640036 + shape: (1000, 128) + 1: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.6649888157199006 + avg_max: 1.5336890253437756 + mean: 0.03420270613169851 + std: 0.6040079522209375 + shape: (1000, 128) + output: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.728540082865832 + avg_max: 1.5992073092841723 + mean: 0.014156845084580252 + std: 0.4875919206474072 + shape: (1000, 256) +mlp.0: + inputs: + 0: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.728540082865832 + avg_max: 1.5992073092841723 + mean: 0.014156845084580252 + std: 0.4875919206474072 + shape: (1000, 256) + output: + min: -25.551782608032227 + max: 30.255319595336914 + avg_min: -10.410937029339797 + avg_max: 12.199230058019891 + mean: -1.3596681231175995 + std: 3.435064250450655 + shape: (1000, 256) +mlp.1: + inputs: + 0: + min: 0.0 + max: 30.255319595336914 + avg_min: 0.0 + avg_max: 12.199230058019891 + mean: 0.6925220241296313 + std: 1.6946167814794522 + shape: (1000, 256) + output: + min: -231.78152465820312 + max: 82.65782165527344 + avg_min: -60.22239387931445 + avg_max: 21.538379845476946 + mean: -11.969065733546032 + std: 16.571529820325505 + shape: (1000, 128) +mlp.2: + inputs: + 0: + min: 0.0 + max: 82.65782165527344 + avg_min: 0.0 + avg_max: 21.538379845476946 + mean: 1.362308199108358 + std: 3.989317218613674 + shape: (1000, 128) + output: + min: -235.94625854492188 + max: 203.7071990966797 + avg_min: -54.71186078950498 + avg_max: 29.957007628019554 + mean: -7.084710118819055 + std: 21.72214522135367 + shape: (1000, 64) +mlp_relu.0: + inputs: + 0: + min: -25.551782608032227 + max: 30.255319595336914 + avg_min: -10.410937029339797 + avg_max: 12.199230058019891 + mean: -1.3596681231175995 + std: 3.435064250450655 + shape: (1000, 256) + output: + min: 0.0 + max: 30.255319595336914 + avg_min: 0.0 + avg_max: 12.199230058019891 + mean: 0.6925220241296313 + std: 1.6946167814794522 + shape: (1000, 256) +mlp_relu.1: + inputs: + 0: + min: -231.78152465820312 + max: 82.65782165527344 + avg_min: -60.22239387931445 + avg_max: 21.538379845476946 + mean: -11.969065733546032 + std: 16.571529820325505 + shape: (1000, 128) + output: + min: 0.0 + max: 82.65782165527344 + avg_min: 0.0 + avg_max: 21.538379845476946 + mean: 1.362308199108358 + std: 3.989317218613674 + shape: (1000, 128) +mlp_relu.2: + inputs: + 0: + min: -235.94625854492188 + max: 203.7071990966797 + avg_min: -54.71186078950498 + avg_max: 29.957007628019554 + mean: -7.084710118819055 + std: 21.72214522135367 + shape: (1000, 64) + output: + min: 0.0 + max: 203.7071990966797 + avg_min: 0.0 + avg_max: 29.95700772112138 + mean: 4.937804440873921 + std: 11.42688696572864 + shape: (1000, 64) +final_concat: + inputs: + 0: + min: -6.388758659362793 + max: 7.461198329925537 + avg_min: -1.4091147140708975 + avg_max: 1.190038402591932 + mean: -0.03373257358976464 + std: 0.47984411373143276 + shape: (1000, 64) + 1: + min: 0.0 + max: 203.7071990966797 + avg_min: 0.0 + avg_max: 29.95700772112138 + mean: 4.937804440873921 + std: 11.42688696572864 + shape: (1000, 64) + output: + min: -6.388758659362793 + max: 203.7071990966797 + avg_min: -1.4091147140708975 + avg_max: 29.95700872795735 + mean: 2.452035934175203 + std: 8.46057860792389 + shape: (1000, 128) +final: + inputs: + 0: + min: -6.388758659362793 + max: 203.7071990966797 + avg_min: -1.4091147140708975 + avg_max: 29.95700872795735 + mean: 2.452035934175203 + std: 8.46057860792389 + shape: (1000, 128) + output: + min: -264.23663330078125 + max: 10.719743728637695 + avg_min: -64.09727749207161 + avg_max: 4.514405594118789 + mean: -27.331936557333087 + std: 29.674832823876194 + shape: (1000, 1) +sigmoid: + inputs: + 0: + min: -264.23663330078125 + max: 10.719743728637695 + avg_min: -64.09727749207161 + avg_max: 4.514405594118789 + mean: -27.331936557333087 + std: 29.674832823876194 + shape: (1000, 1) + output: + min: 0.0 + max: 0.9999779462814331 + avg_min: 7.22119551814259e-08 + avg_max: 0.9796236092019589 + mean: 0.025780337072657727 + std: 0.11967490732565136 + shape: (1000, 1) diff --git a/examples/ncf/quantization_stats_split.yaml b/examples/ncf/quantization_stats_split.yaml new file mode 100644 index 0000000000000000000000000000000000000000..afbc00409ba2d336701d5a21095b67c68400db21 --- /dev/null +++ b/examples/ncf/quantization_stats_split.yaml @@ -0,0 +1,312 @@ +mf_user_embed: + inputs: + 0: + min: 0 + max: 13848 + avg_min: 6924.0 + avg_max: 6924.0 + mean: 6924.0 + std: 3997.8620729189115 + shape: (1000) + output: + min: -2.3631081581115723 + max: 2.634204864501953 + avg_min: -1.071664170499179 + avg_max: 1.10968593620068 + mean: 0.004670127670715561 + std: 0.4853287812380178 + shape: (1000, 64) +mf_item_embed: + inputs: + 0: + min: 0 + max: 26743 + avg_min: 32.519748718318944 + avg_max: 26716.69644017626 + mean: 13429.908568849716 + std: 7695.305228284578 + shape: (1000) + output: + min: -3.886291742324829 + max: 3.3996453285217285 + avg_min: -1.8537866596848787 + avg_max: 1.733846694667886 + mean: -0.1020138903885282 + std: 0.9169524390367664 + shape: (1000, 64) +mlp_user_embed: + inputs: + 0: + min: 0 + max: 13848 + avg_min: 6924.0 + avg_max: 6924.0 + mean: 6924.0 + std: 3997.8620729189115 + shape: (1000) + output: + min: -2.3077099323272705 + max: 2.019761323928833 + avg_min: -0.8393908124364596 + avg_max: 0.8563097013907461 + mean: -0.0058890159863465566 + std: 0.33145011084640036 + shape: (1000, 128) +mlp_item_embed: + inputs: + 0: + min: 0 + max: 26743 + avg_min: 32.519748718318944 + avg_max: 26716.69644017626 + mean: 13429.908568849716 + std: 7695.305228284578 + shape: (1000) + output: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.6649888157199006 + avg_max: 1.5336890253437756 + mean: 0.03420270613169851 + std: 0.6040079522209375 + shape: (1000, 128) +mf_mult: + inputs: + 0: + min: -2.3631081581115723 + max: 2.634204864501953 + avg_min: -1.071664170499179 + avg_max: 1.10968593620068 + mean: 0.004670127670715561 + std: 0.4853287812380178 + shape: (1000, 64) + 1: + min: -3.886291742324829 + max: 3.3996453285217285 + avg_min: -1.8537866596848787 + avg_max: 1.733846694667886 + mean: -0.1020138903885282 + std: 0.9169524390367664 + shape: (1000, 64) + output: + min: -6.388758659362793 + max: 7.461198329925537 + avg_min: -1.4091147140708975 + avg_max: 1.190038402591932 + mean: -0.03373257358976464 + std: 0.47984411373143276 + shape: (1000, 64) +mlp_concat: + inputs: + 0: + min: -2.3077099323272705 + max: 2.019761323928833 + avg_min: -0.8393908124364596 + avg_max: 0.8563097013907461 + mean: -0.0058890159863465566 + std: 0.33145011084640036 + shape: (1000, 128) + 1: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.6649888157199006 + avg_max: 1.5336890253437756 + mean: 0.03420270613169851 + std: 0.6040079522209375 + shape: (1000, 128) + output: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.728540082865832 + avg_max: 1.5992073092841723 + mean: 0.014156845084580252 + std: 0.4875919206474072 + shape: (1000, 256) +mlp.0: + inputs: + 0: + min: -4.184338569641113 + max: 3.6927380561828613 + avg_min: -1.728540082865832 + avg_max: 1.5992073092841723 + mean: 0.014156845084580252 + std: 0.4875919206474072 + shape: (1000, 256) + output: + min: -25.551782608032227 + max: 30.255319595336914 + avg_min: -10.410937029339797 + avg_max: 12.199230058019891 + mean: -1.3596681231175995 + std: 3.435064250450655 + shape: (1000, 256) +mlp.1: + inputs: + 0: + min: 0.0 + max: 30.255319595336914 + avg_min: 0.0 + avg_max: 12.199230058019891 + mean: 0.6925220241296313 + std: 1.6946167814794522 + shape: (1000, 256) + output: + min: -231.78152465820312 + max: 82.65782165527344 + avg_min: -60.22239387931445 + avg_max: 21.538379845476946 + mean: -11.969065733546032 + std: 16.571529820325505 + shape: (1000, 128) +mlp.2: + inputs: + 0: + min: 0.0 + max: 82.65782165527344 + avg_min: 0.0 + avg_max: 21.538379845476946 + mean: 1.362308199108358 + std: 3.989317218613674 + shape: (1000, 128) + output: + min: -235.94625854492188 + max: 203.7071990966797 + avg_min: -54.71186078950498 + avg_max: 29.957007628019554 + mean: -7.084710118819055 + std: 21.72214522135367 + shape: (1000, 64) +mlp_relu.0: + inputs: + 0: + min: -25.551782608032227 + max: 30.255319595336914 + avg_min: -10.410937029339797 + avg_max: 12.199230058019891 + mean: -1.3596681231175995 + std: 3.435064250450655 + shape: (1000, 256) + output: + min: 0.0 + max: 30.255319595336914 + avg_min: 0.0 + avg_max: 12.199230058019891 + mean: 0.6925220241296313 + std: 1.6946167814794522 + shape: (1000, 256) +mlp_relu.1: + inputs: + 0: + min: -231.78152465820312 + max: 82.65782165527344 + avg_min: -60.22239387931445 + avg_max: 21.538379845476946 + mean: -11.969065733546032 + std: 16.571529820325505 + shape: (1000, 128) + output: + min: 0.0 + max: 82.65782165527344 + avg_min: 0.0 + avg_max: 21.538379845476946 + mean: 1.362308199108358 + std: 3.989317218613674 + shape: (1000, 128) +mlp_relu.2: + inputs: + 0: + min: -235.94625854492188 + max: 203.7071990966797 + avg_min: -54.71186078950498 + avg_max: 29.957007628019554 + mean: -7.084710118819055 + std: 21.72214522135367 + shape: (1000, 64) + output: + min: 0.0 + max: 203.7071990966797 + avg_min: 0.0 + avg_max: 29.95700772112138 + mean: 4.937804440873921 + std: 11.42688696572864 + shape: (1000, 64) +final_mlp: + inputs: + 0: + min: 0.0 + max: 203.7071990966797 + avg_min: 0.0 + avg_max: 29.95700772112138 + mean: 4.937804440873921 + std: 11.42688696572864 + shape: (1000, 64) + output: + min: -283.5218200683594 + max: 40.86720275878906 + avg_min: -60.916563263919905 + avg_max: 6.9745166829700596 + mean: -22.90735553673859 + std: 35.900340843176544 + shape: (1000, 1) +final_mf: + inputs: + 0: + min: -6.388758659362793 + max: 7.461198329925537 + avg_min: -1.4091147140708975 + avg_max: 1.190038402591932 + mean: -0.03373257358976464 + std: 0.47984411373143276 + shape: (1000, 64) + output: + min: -54.07410430908203 + max: 47.101890563964844 + avg_min: -16.45866186288787 + avg_max: 7.688797266098178 + mean: -4.42458063103854 + std: 9.54055961838881 + shape: (1000, 1) +final_add: + inputs: + 0: + min: -54.07410430908203 + max: 47.101890563964844 + avg_min: -16.45866186288787 + avg_max: 7.688797266098178 + mean: -4.42458063103854 + std: 9.54055961838881 + shape: (1000, 1) + 1: + min: -283.5218200683594 + max: 40.86720275878906 + avg_min: -60.916563263919905 + avg_max: 6.9745166829700596 + mean: -22.90735553673859 + std: 35.900340843176544 + shape: (1000, 1) + output: + min: -264.23663330078125 + max: 10.719744682312012 + avg_min: -64.09727644866952 + avg_max: 4.514405736883009 + mean: -27.33193622736199 + std: 29.674832823876194 + shape: (1000, 1) +sigmoid: + inputs: + 0: + min: -264.23663330078125 + max: 10.719744682312012 + avg_min: -64.09727644866952 + avg_max: 4.514405736883009 + mean: -27.33193622736199 + std: 29.674832823876194 + shape: (1000, 1) + output: + min: 0.0 + max: 0.9999779462814331 + avg_min: 7.221196356095192e-08 + avg_max: 0.9796236107793396 + mean: 0.025780336994990046 + std: 0.11967490732565136 + shape: (1000, 1) diff --git a/examples/ncf/requirements.txt b/examples/ncf/requirements.txt deleted file mode 100644 index cfa65ea03548eeef5e7b2d7d60ea620106ff8e49..0000000000000000000000000000000000000000 --- a/examples/ncf/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -tqdm==4.20.0 -scipy -pandas diff --git a/requirements.txt b/requirements.txt index 06e56c4711c81dce3432bb7ce938c2baea71fe3e..c69c7065753c638ef5c5a77c9d7077fbc6855c43 100755 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ xlsxwriter>=1.1.1 pretrainedmodels==0.7.4 scikit-learn==0.21.2 gym==0.12.5 +tqdm==4.33.0