pytorch 交叉损失熵函数及CLIP中的对比函数

交叉损失熵函数

在深度学习进行分类任务时经常用到交叉损失熵函数:首先定义logic表示余弦相似度,labels表示真实标签。可以直接使用交叉熵损失函数

import torch.nn as nn
Loss = nn.CrossEntropyLoss(logic, labels)

其中logic是一个NXC大小的数组,N表示有多少个样本(BatchSize),C表示类别总数。例如logic=\begin{bmatrix} 0.5\ 0.4\ 0.2\ 0.7 \\ 0.2\ 0.5\ 0.3\ 0.8 \\0.1\ 0.5\ 0.4\ 0.1\end{bmatrix}

        表示有三个样本,四个类别。

        而labels是一个一维数组,大小为N。表示N个样本所对应的真实类别。如label=[3,0,1],表示一个有三个样本,第一个样本对应的真实标签为3,第二个真实标签为0,第三个真实标签为1.

交叉熵函数计算公式为:

loss = -\sum_{i}^{N} labels[i]*\ln logic[i]

其中,可以认为在pytorch函数内部,会将label转换为一个NXC大小的数组,并且只有在真实类别处为1,其它为0。例如会将label[3,0,1]转化为\begin{bmatrix} 0\ 0\ 0\ 1 \\ 1\ 0\ 0\ 0 \\0\ 1\ 0\ 0\end{bmatrix},则可以将公式化简为:

loss = \tfrac{-\sum_{i}^{N} \ln logic[i][j]}{N}

即将每个样本真实标签的概率求平均。而\ln logic可以近似为先对logic进行softmax,得到每一个样本对应不同类别的概率,之后在取log。参考:http://t.csdnimg.cn/49mKj

CLIP中的对比损失函数:

伪代码如下:

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter

# 分别提取图像特征和文本特征
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]

# 对两个特征进行线性投射,得到相同维度的特征,并进行l2归一化
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# 计算缩放的余弦相似度:[n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)

# 对称的对比学习损失:等价于N个类别的cross_entropy_loss
labels = np.arange(n) # 对角线元素的labels
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2

 可以看到在计算缩放的余弦相似度后,分别使用了两个交叉损失熵函数来分别计算图像和文本的损失。这里结合图像-文本对比训练图来理解:

 可以看到,每个对角线元素为正样本,可以设置为其对应的标签。所以label是一个从0到N-1对的数组。从行的角度出发,则可以计算图像的损失函数,从列的角度出发,则可以计算文本的损失函数。最后将二者的均值作为对比损失函数。

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值