多模态text-image模型之ITC loss (blip)

ALBEF代码中ITC loss对应的主要代码:

        sim_i2t = image_feat @ text_feat_all / self.temp 
        sim_t2i = text_feat @ image_feat_all / self.temp 
        # image_feat和text_feat分别是图片和文本特征,text_feat_all和image_feat_all是从momentum encoder中取出来的文本、图像特征
        # self.temp = nn.Parameter(torch.ones([]) * config['temp']) ,引入一个可学习的参数,可以对计算的结果进行缩放,从而调整模型
                             
        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
        # F.log_softmax(sim_i2t, dim=1)对sim_i2t的每一行进行log_softmax计算
        # sim_i2t和sim_i2t_targets的形状一样,sim_i2t_targets是真实值
        # F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets是矩阵按元素相乘
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 
        # loss_t2i中的操作同上

        loss_ita = (loss_i2t+loss_t2i)/2 #求平均得到ITC Loss

参考:多模态text-image模型之ITC loss_itcloss-CSDN博客

Loss代码解读:

loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
  1. F.log_softmax(sim_i2t, dim=1):这一部分首先对输入的相似性得分 sim_i2t 在维度 1 上进行 softmax 操作,然后取对数。Softmax 操作将相似性得分转换为概率分布,而取对数操作有助于数值稳定性和数学推导。结果是一个经过对数变换的概率分布。

  2. * sim_i2t_targets:这一部分将对数概率分布与目标概率分布 sim_i2t_targets 相乘。目标概率分布通常是正样本的标签分布。这个操作相当于将模型预测的概率分布与真实标签进行对齐,以便计算损失。

  3. -torch.sum(...,dim=1):这一部分对结果进行求和,但是仅在维度 1 上求和。这相当于将每个样本的损失相加起来,但保持了 batch 的维度。

  4. .mean():最后,对所有样本的损失求均值,得到最终的损失值。这个操作将 batch 中所有样本的损失平均化,以得到一个可比较的损失值。

综合起来,这一句代码计算了图像到文本的对比学习损失,通过将模型预测的概率分布真实标签进行对齐,然后计算损失并求均值。

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值