Unverified commit ca96658d authored by Guy Jacob, committed by GitHub

Remove debug comments

parent 6b39f1fa
@@ -75,8 +75,6 @@ class NeuMF(nn.Module):
         lecunn_uniform(self.final_mlp)
         lecunn_uniform(self.final_mf)
 
-        # self.post_embed_device = torch.device('cpu')
-
     def load_state_dict(self, state_dict, strict=True):
         if 'final.weight' in state_dict and self.split_final:
             # Loading no-split checkpoint into split model
@@ -99,53 +97,24 @@ class NeuMF(nn.Module):
         super(NeuMF, self).load_state_dict(state_dict, strict)
 
     def forward(self, user, item, sigmoid):
-        xmfu = self.mf_user_embed(user)  # .to(self.post_embed_device)
-        xmfi = self.mf_item_embed(item)  # .to(self.post_embed_device)
+        xmfu = self.mf_user_embed(user)
+        xmfi = self.mf_item_embed(item)
         xmf = self.mf_mult(xmfu, xmfi)
-        # @DEBUG
-        # np.save(os.path.join(msglogger.logdir, 'mf_mult'), xmf.cpu().detach().numpy())
 
-        xmlpu = self.mlp_user_embed(user)  # .to(self.post_embed_device)
-        xmlpi = self.mlp_item_embed(item)  # .to(self.post_embed_device)
+        xmlpu = self.mlp_user_embed(user)
+        xmlpi = self.mlp_item_embed(item)
         xmlp = self.mlp_concat(xmlpu, xmlpi)
-        # @DEBUG
-        # np.save(os.path.join(msglogger.logdir, 'mlp_concat'), xmlp.cpu().detach().numpy())
         for i, (layer, act) in enumerate(zip(self.mlp, self.mlp_relu)):
             xmlp = layer(xmlp)
-            # @DEBUG
-            # np.save(os.path.join(msglogger.logdir, 'mlp.{}'.format(i)), xmlp.detach().cpu().numpy())
             xmlp = act(xmlp)
-            # @DEBUG
-            # np.save(os.path.join(msglogger.logdir, 'mlp_relu.{}'.format(i)), xmlp.detach().cpu().numpy())
 
         if not self.split_final:
             x = self.final_concat(xmf, xmlp)
             x = self.final(x)
         else:
             xmf = self.final_mf(xmf)
-            # @DEBUG
-            # np.save(os.path.join(msglogger.logdir, 'final_mf'), xmf.detach().cpu().numpy())
             xmlp = self.final_mlp(xmlp)
-            # @DEBUG
-            # np.save(os.path.join(msglogger.logdir, 'final_mlp'), xmlp.detach().cpu().numpy())
             x = self.final_add(xmf, xmlp)
-            # @DEBUG
-            # np.save(os.path.join(msglogger.logdir, 'final_add'), x.detach().cpu().numpy())
 
         if sigmoid:
             x = self.sigmoid(x)
         return x
-
-    # def to_cuda(self, device=None, embeds_on_gpu=True):
-    #     self.post_embed_device = device if device is not None else torch.device('cuda')
-    #
-    #     if embeds_on_gpu:
-    #         return self.cuda(device=device)
-    #
-    #     for m in self.modules():
-    #         if isinstance(m, nn.Embedding):
-    #             m.cpu()
-    #         else:
-    #             m.cuda(device=device)
-    #
-    #     return self
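For reference, after this commit the forward pass reduces to the plain two-branch NeuMF computation: a GMF branch (element-wise product of user and item embeddings) and an MLP branch (concatenated embeddings pushed through a stack of Linear + ReLU layers), joined either by one fused final layer or, when split_final is set, by two separate heads whose outputs are added. The sketch below is a self-contained reconstruction, not the repository file: the embedding and layer sizes are made up for illustration, and plain tensor ops stand in for the model's mf_mult, mlp_concat, final_concat and final_add wrapper modules.

import torch
import torch.nn as nn

class NeuMFSketch(nn.Module):
    """Minimal stand-in for the NeuMF forward path shown in the diff.

    Assumptions (not taken from the commit): embedding sizes, layer
    widths, and plain tensor ops in place of the model's elementwise
    multiply / concat / add wrapper modules.
    """
    def __init__(self, nb_users, nb_items, mf_dim=8, mlp_dims=(32, 16),
                 split_final=False):
        super().__init__()
        self.split_final = split_final
        # GMF branch embeddings
        self.mf_user_embed = nn.Embedding(nb_users, mf_dim)
        self.mf_item_embed = nn.Embedding(nb_items, mf_dim)
        # MLP branch embeddings: concatenated pair feeds the first layer
        self.mlp_user_embed = nn.Embedding(nb_users, mlp_dims[0] // 2)
        self.mlp_item_embed = nn.Embedding(nb_items, mlp_dims[0] // 2)
        self.mlp = nn.ModuleList(nn.Linear(i, o)
                                 for i, o in zip(mlp_dims[:-1], mlp_dims[1:]))
        self.mlp_relu = nn.ModuleList(nn.ReLU() for _ in self.mlp)
        if split_final:
            # Two separate heads whose scalar outputs are summed
            self.final_mf = nn.Linear(mf_dim, 1)
            self.final_mlp = nn.Linear(mlp_dims[-1], 1)
        else:
            # One fused head over the concatenated branch outputs
            self.final = nn.Linear(mf_dim + mlp_dims[-1], 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, user, item, sigmoid=False):
        xmf = self.mf_user_embed(user) * self.mf_item_embed(item)   # GMF branch
        xmlp = torch.cat([self.mlp_user_embed(user),
                          self.mlp_item_embed(item)], dim=-1)       # MLP branch
        for layer, act in zip(self.mlp, self.mlp_relu):
            xmlp = act(layer(xmlp))
        if not self.split_final:
            x = self.final(torch.cat([xmf, xmlp], dim=-1))
        else:
            x = self.final_mf(xmf) + self.final_mlp(xmlp)
        return self.sigmoid(x) if sigmoid else x

A quick smoke test, under the same assumptions:

model = NeuMFSketch(nb_users=100, nb_items=200, split_final=True)
scores = model(torch.tensor([0, 1]), torch.tensor([5, 7]), sigmoid=True)  # shape [2, 1]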