交叉熵(Cross entropy)和InfoNCE

多分类问题,一个样本的交叉熵损失函数为:

L C E = − ∑ c = 1 M y o , c l o g ( p o , c ) L_{CE}=-\sum^M_{c=1} y_{o,c} log(p_{o,c}) LCE=c=1Myo,clog(po,c)

其中:

  • M:类别数
  • y o , c : y o 是 o n e − h o t 编码的向量,代表这个样本的真实标签。 c 为某位置上分量 y_{o,c}:y_o是one-hot编码的向量,代表这个样本的真实标签。c为某位置上分量 yo,cyoonehot编码的向量,代表这个样本的真实标签。c为某位置上分量
  • p o , c : 模型预测样本 o 属于类别 c 的概率。 p_{o,c}:模型预测样本 o 属于类别 c 的概率。 po,c:模型预测样本o属于类别c的概率。
  • p o , c : 通常是 s o f t m a x 计算。再 t o r c h 中的 C E 方法自动对输入的 l o g i t s 先算 s o f t m a x 再算 C E p_{o,c}:通常是softmax计算。再torch中的CE方法自动对输入的logits先算softmax再算CE po,c:通常是softmax计算。再torch中的CE方法自动对输入的logits先算softmax再算CE

举例计算

假设的真实标签和模型预测:

  • 样本1的真实标签是类别3(one-hot编码向量为[0, 0, 1, 0, 0])。
  • 样本2的真实标签是类别1(one-hot编码向量为[1, 0, 0, 0, 0])。
  • 样本3的真实标签是类别5(one-hot编码向量为[0, 0, 0, 0, 1])。
  • 模型预测的概率分布为:
    • 样本1的预测概率:[0.1, 0.2, 0.5, 0.1, 0.1]。
    • 样本2的预测概率:[0.3, 0.3, 0.1, 0.2, 0.1]。
    • 样本3的预测概率:[0.2, 0.1, 0.1, 0.1, 0.5]。

计算交叉熵损失的步骤:

计算每个样本的损失
在这里插入图片描述

计算对数概率:

在这里插入图片描述

计算最终损失:

在这里插入图片描述

求和取平均

在这里插入图片描述

所以一个batch内的多分类问题中的CE损失公式可以表示为:
在这里插入图片描述

再来看看单个样本的InfoNCE:
在这里插入图片描述

在这里插入图片描述

出自MoCo:
可以理解为,对一个q,计算它和batch(K+1个)中每个k的相似度 q ∗ k i q*k_{i} qki得到一个相似度向量。我们最大化q和自己正样本k+的相似度。相当于把向量 q ∗ k i q*k_{i} qki看作CE的输入,one-hot向量为:正样本上为1,其他位置为0的。也就是对相似度做K+1的CE。使得和正样本的相似度最高。

batch大小为N的InfoNCE:
在这里插入图片描述

  • 21
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值