多模态text-image模型之ITC loss (blip)

ALBEF代码中ITC loss对应的主要代码:

        sim_i2t = image_feat @ text_feat_all / self.temp 
        sim_t2i = text_feat @ image_feat_all / self.temp 
        # image_feat和text_feat分别是图片和文本特征,text_feat_all和image_feat_all是从momentum encoder中取出来的文本、图像特征
        # self.temp = nn.Parameter(torch.ones([]) * config['temp']) ,引入一个可学习的参数,可以对计算的结果进行缩放,从而调整模型
                             
        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
        # F.log_softmax(sim_i2t, dim=1)对sim_i2t的每一行进行log_softmax计算
        # sim_i2t和sim_i2t_targets的形状一样,sim_i2t_targets是真实值
        # F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets是矩阵按元素相乘
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 
        # loss_t2i中的操作同上

        loss_ita = (loss_i2t+loss_t2i)/2 #求平均得到ITC Loss

参考:多模态text-image模型之ITC loss_itcloss-CSDN博客

Loss代码解读:

loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
  1. F.log_softmax(sim_i2t, dim=1):这一部分首先对输入的相似性得分 sim_i2t 在维度 1 上进行 softmax 操作,然后取对数。Softmax 操作将相似性得分转换为概率分布,而取对数操作有助于数值稳定性和数学推导。结果是一个经过对数变换的概率分布。

  2. * sim_i2t_targets:这一部分将对数概率分布与目标概率分布 sim_i2t_targets 相乘。目标概率分布通常是正样本的标签分布。这个操作相当于将模型预测的概率分布与真实标签进行对齐,以便计算损失。

  3. -torch.sum(...,dim=1):这一部分对结果进行求和,但是仅在维度 1 上求和。这相当于将每个样本的损失相加起来,但保持了 batch 的维度。

  4. .mean():最后,对所有样本的损失求均值,得到最终的损失值。这个操作将 batch 中所有样本的损失平均化,以得到一个可比较的损失值。

综合起来,这一句代码计算了图像到文本的对比学习损失,通过将模型预测的概率分布真实标签进行对齐,然后计算损失并求均值。

### ITC Loss 的概念与实现 ITC (Image-Text Contrastive) 损失是一种用于多模态学习中的对比学习技术,主要用于训练图像和文本之间的联合表示。其核心目标是使匹配的图像-文本对在嵌入空间中更接近,而非匹配的对则尽可能远离。 #### 对比学习的核心原理 对比学习通过最大化正样本对之间的一致性和最小化负样本对之间的相似度来优化模型参数。具体来说,在文图模型中,给定一批数据 \( \{(I_i, T_i)\} \),其中 \( I_i \) 表示第 \( i \) 张图片,\( T_i \) 是对应的文本描述,模型会分别提取它们的特征向量 \( f(I_i) \) 和 \( g(T_i) \)[^1]。这些特征通常被映射到同一维度的空间以便比较。 #### ITC Loss 数学表达 假设我们有一批大小为 \( N \) 的数据,则对于每张图片 \( I_i \),它与其他所有文本形成一对正样本和多个负样本。ITC Loss 可以形式化如下: \[ L_{ITC} = -\frac{1}{N}\sum_{i=1}^{N} \log{\frac{\exp(\text{sim}(f(I_i), g(T_i))/\tau)}{\sum_{j=1}^{N}\exp(\text{sim}(f(I_i), g(T_j))/\tau)}} \] 这里: - \( \text{sim}(a, b) \) 通常是两个向量的余弦相似度; - \( \tau \) 称为温度超参数,控制分布的锐利程度[^2]。 该公式的作用是对每个图像找到最可能配对的文本,并惩罚那些错误关联的情况。 #### 实现代码示例 以下是基于 PyTorch 的简单实现方式: ```python import torch import torch.nn.functional as F def itc_loss(image_embeddings, text_embeddings, temperature=0.07): """ 计算 Image-Text Contrastive Loss 参数: image_embeddings: 图像嵌入矩阵, shape=(batch_size, embedding_dim) text_embeddings: 文本嵌入矩阵, shape=(batch_size, embedding_dim) temperature: 温度超参数,默认值为0.07 返回: loss: 标量损失值 """ batch_size = image_embeddings.shape[0] # 归一化嵌入向量 image_embeddings_normed = F.normalize(image_embeddings, dim=-1) text_embeddings_normed = F.normalize(text_embeddings, dim=-1) # 计算相似度矩阵 logits_per_image = torch.matmul( image_embeddings_normed, text_embeddings_normed.t() ) / temperature # 构建 ground truth labels ground_truth_labels = torch.arange(batch_size, device=image_embeddings.device) # 使用交叉熵作为最终损失函数 loss_i = F.cross_entropy(logits_per_image, ground_truth_labels) loss_t = F.cross_entropy(logits_per_image.t(), ground_truth_labels) return (loss_i + loss_t) / 2 ``` 上述代码实现了双向对比学习过程,即不仅考虑从图像到文本的方向,还反向评估从文本到图像的效果。 ### 总结 通过对齐图像和文本的潜在表征并通过引入温度调节机制增强区分能力,ITC Loss 成为了构建高效跨模态检索系统的基石之一。这种设计使得即使是在大规模无监督场景下也能有效提升性能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值