推荐系统对比学习中的损失函数

一、Info Noise-contrastive estimation(Info NCE)

  最近在基于对比学习做实验,github有许多实现,虽然直接套用即可,但是细看之下,损失函数部分甚是疑惑,故学习并记录于此。关于对比学习的内容网络上已经有很多内容了,因此不再赘述。本文重在对InfoNCE的两种实现方式的记录。

一、Info Noise-contrastive estimation(Info NCE)

1. 实现

MoCo源码的\moco\builder.py中,实现如下:

# compute logits

# Einstein sum is more intuitive

# positive logits: Nx1

l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)

# negative logits: NxK

l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

# logits: Nx(1+K)

logits = torch.cat([l_pos, l_neg], dim=1)

# apply temperature

logits /= self.T

# labels: positive key indicators

labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

...

return logits, labels

这里的变量logits的意义我也查了一下:是未进入softmax的概率

这段代码根据注释即可理解:l_pos表示正样本的得分,l_neg表示所有负样本的得分,logits表示将正样本和负样本在列上cat起来之后的值。值得关注的是,labels的数值,是根据logits.shape[0]的大小生成的一组zero。也就是大小为batch_size的一组0。

  接下来看损失函数部分,

# define loss function (criterion) and optimizer

criterion = nn.CrossEntropyLoss().cuda(args.gpu)

...

# compute output

output, target = model(im_q=images[0], im_k=images[1])

loss = criterion(output, target)

这里直接对输出的logits和生成的labels计算交叉熵,然后就是模型的loss。这里就是让我不是很理解的地方。先将疑惑埋在心里~

二、HCL

1  实现

  但是在这篇文章的实现中,\image\main.py:

def criterion(out_1,out_2,tau_plus,batch_size,beta, estimator):

# neg score

out = torch.cat([out_1, out_2], dim=0)

neg = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)

old_neg = neg.clone()

mask = get_negative_mask(batch_size).to(device)

neg = neg.masked_select(mask).view(2 * batch_size, -1)



# pos score

pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)

pos = torch.cat([pos, pos], dim=0)



# negative samples similarity scoring

if estimator=='hard':

    N = batch_size * 2 - 2

    imp = (beta* neg.log()).exp()

    reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1)

    Ng = (-tau_plus * N * pos + reweight_neg) / (1 - tau_plus)

    # constrain (optional)

    Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))

elif estimator=='easy':

    Ng = neg.sum(dim=-1)

else:

    raise Exception('Invalid estimator selected. Please use any of [hard, easy]')

    

# contrastive loss

loss = (- torch.log(pos / (pos + Ng) )).mean()



return loss

三、文字解释

  既然是同一种方法的两种实现,已经理解了第二种实现(HCL)。那么,问题就出在了:不理解第一种实现的label为何要这样生成? 交叉熵的label的作用是:将label作为索引,来取得x xx中的项(x [ c l a s s ] x[class]x[class]),因此,这些项就是label。而倘若label是全0的项,那么其含义为:x xx中的第一列为label(正样本),其他列就是负样本。然后带入公式(3)中计算,即可得到交叉熵下的loss值。

  而对于HCL的实现方式,是直接将InfoNCE拆解开来,使用正样本的得分和负样本的得分来计算。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值