From 81787436cb6d5c3fbf2e857ef44567aa85158d71 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Tue, 23 Jul 2019 17:14:53 +0300
Subject: [PATCH] NCF changes to make it compatible with latest changes in
 master

* Pass the 'sigmoid' flag in NeuMF.forward as a bool tensor instead of
  a simple boolean. Required to make the model traceable (it?d be better
  to not have it an argument of forward at all, but keeping changes to
  a minimum)
* Call prepare_model with dummy_input
---
 examples/ncf/ncf.py   | 7 ++++---
 examples/ncf/neumf.py | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/ncf/ncf.py b/examples/ncf/ncf.py
index a32fa98..b3148ea 100644
--- a/examples/ncf/ncf.py
+++ b/examples/ncf/ncf.py
@@ -103,7 +103,7 @@ def predict(model, users, items, batch_size=1024, use_cuda=True):
                 if use_cuda:
                     x = x.cuda(async=True)
                 return torch.autograd.Variable(x)
-            outp = model(proc(user), proc(item), sigmoid=True)
+            outp = model(proc(user), proc(item), torch.tensor([True], dtype=torch.bool))
             outp = outp.data.cpu().numpy()
             preds += list(outp.flatten())
         return preds
@@ -354,7 +354,8 @@ def main():
         if args.quantize_eval and args.qe_calibration is None:
             model.cpu()
             quantizer = quantization.PostTrainLinearQuantizer.from_args(model, args)
-            quantizer.prepare_model()
+            dummy_input = (torch.tensor([1]), torch.tensor([1]), torch.tensor([True], dtype=torch.bool))
+            quantizer.prepare_model(dummy_input)
             model.cuda()
 
         distiller.utils.assign_layer_fq_names(model)
@@ -406,7 +407,7 @@ def main():
             if compression_scheduler:
                 compression_scheduler.on_minibatch_begin(epoch, batch_index, steps_per_epoch, optimizer)
 
-            outputs = model(user, item)
+            outputs = model(user, item, torch.tensor([False], dtype=torch.bool))
             loss = criterion(outputs, label)
 
             if compression_scheduler:
diff --git a/examples/ncf/neumf.py b/examples/ncf/neumf.py
index dbf71b3..93b50ca 100644
--- a/examples/ncf/neumf.py
+++ b/examples/ncf/neumf.py
@@ -98,7 +98,7 @@ class NeuMF(nn.Module):
 
         super(NeuMF, self).load_state_dict(state_dict, strict)
 
-    def forward(self, user, item, sigmoid=False):
+    def forward(self, user, item, sigmoid):
         xmfu = self.mf_user_embed(user)  # .to(self.post_embed_device)
         xmfi = self.mf_item_embed(item)  # .to(self.post_embed_device)
         xmf = self.mf_mult(xmfu, xmfi)
-- 
GitLab