ALBEF代码中ITC loss对应的主要代码:
sim_i2t = image_feat @ text_feat_all / self.temp
sim_t2i = text_feat @ image_feat_all / self.temp
# image_feat和text_feat分别是图片和文本特征,text_feat_all和image_feat_all是从momentum encoder中取出来的文本、图像特征
# self.temp = nn.Parameter(torch.ones([]) * config['temp']) ,引入一个可学习的参数,可以对计算的结果进行缩放,从而调整模型
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
# F.log_softmax(sim_i2t, dim=1)对sim_i2t的每一行进行log_softmax计算
# sim_i2t和sim_i2t_targets的形状一样,sim_i2t_targets是真实值
# F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets是矩阵按元素相乘
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
# loss_t2i中的操作同上
loss_ita = (loss_i2t+loss_t2i)/2 #求平均得到ITC Loss
参考:多模态text-image模型之ITC loss_itcloss-CSDN博客
Loss代码解读:
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
-
F.log_softmax(sim_i2t, dim=1)
:这一部分首先对输入的相似性得分sim_i2t
在维度 1 上进行 softmax 操作,然后取对数。Softmax 操作将相似性得分转换为概率分布,而取对数操作有助于数值稳定性和数学推导。结果是一个经过对数变换的概率分布。 -
* sim_i2t_targets
:这一部分将对数概率分布与目标概率分布sim_i2t_targets
相乘。目标概率分布通常是正样本的标签分布。这个操作相当于将模型预测的概率分布与真实标签进行对齐,以便计算损失。 -
-torch.sum(...,dim=1)
:这一部分对结果进行求和,但是仅在维度 1 上求和。这相当于将每个样本的损失相加起来,但保持了 batch 的维度。 -
.mean()
:最后,对所有样本的损失求均值,得到最终的损失值。这个操作将 batch 中所有样本的损失平均化,以得到一个可比较的损失值。
综合起来,这一句代码计算了图像到文本的对比学习损失,通过将模型预测的概率分布与真实标签进行对齐,然后计算损失并求均值。