比如embedding维度为64,batch_size为512,负采样比例为1时一切都很简单:
user_embs.shape #(512,64)
pos_item_embs.shape #(512,64)
neg_item_embs.shape #(512,64)
pos_scores = torch.sum(torch.mul(user_emb, pos_item_emb), axis=1) #(512)
neg_s
比如embedding维度为64,batch_size为512,负采样比例为1时一切都很简单:
user_embs.shape #(512,64)
pos_item_embs.shape #(512,64)
neg_item_embs.shape #(512,64)
pos_scores = torch.sum(torch.mul(user_emb, pos_item_emb), axis=1) #(512)
neg_s