From 81787436cb6d5c3fbf2e857ef44567aa85158d71 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Tue, 23 Jul 2019 17:14:53 +0300 Subject: [PATCH] NCF changes to make it compatible with latest changes in master * Pass the 'sigmoid' flag in NeuMF.forward as a bool tensor instead of a simple boolean. Required to make the model traceable (it?d be better to not have it an argument of forward at all, but keeping changes to a minimum) * Call prepare_model with dummy_input --- examples/ncf/ncf.py | 7 ++++--- examples/ncf/neumf.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/ncf/ncf.py b/examples/ncf/ncf.py index a32fa98..b3148ea 100644 --- a/examples/ncf/ncf.py +++ b/examples/ncf/ncf.py @@ -103,7 +103,7 @@ def predict(model, users, items, batch_size=1024, use_cuda=True): if use_cuda: x = x.cuda(async=True) return torch.autograd.Variable(x) - outp = model(proc(user), proc(item), sigmoid=True) + outp = model(proc(user), proc(item), torch.tensor([True], dtype=torch.bool)) outp = outp.data.cpu().numpy() preds += list(outp.flatten()) return preds @@ -354,7 +354,8 @@ def main(): if args.quantize_eval and args.qe_calibration is None: model.cpu() quantizer = quantization.PostTrainLinearQuantizer.from_args(model, args) - quantizer.prepare_model() + dummy_input = (torch.tensor([1]), torch.tensor([1]), torch.tensor([True], dtype=torch.bool)) + quantizer.prepare_model(dummy_input) model.cuda() distiller.utils.assign_layer_fq_names(model) @@ -406,7 +407,7 @@ def main(): if compression_scheduler: compression_scheduler.on_minibatch_begin(epoch, batch_index, steps_per_epoch, optimizer) - outputs = model(user, item) + outputs = model(user, item, torch.tensor([False], dtype=torch.bool)) loss = criterion(outputs, label) if compression_scheduler: diff --git a/examples/ncf/neumf.py b/examples/ncf/neumf.py index dbf71b3..93b50ca 100644 --- a/examples/ncf/neumf.py +++ b/examples/ncf/neumf.py @@ -98,7 +98,7 @@ class NeuMF(nn.Module): super(NeuMF, self).load_state_dict(state_dict, strict) - def forward(self, user, item, sigmoid=False): + def forward(self, user, item, sigmoid): xmfu = self.mf_user_embed(user) # .to(self.post_embed_device) xmfi = self.mf_item_embed(item) # .to(self.post_embed_device) xmf = self.mf_mult(xmfu, xmfi) -- GitLab