多模态text-image模型之ITM loss(blip)

主要代码:

# forward the positve image-text pair
# 正向传播正面的图像文本对
output_pos = self.text_encoder.bert(encoder_embeds=text_embeds, 
                                    attention_mask=text.attention_mask,
                                    encoder_hidden_states=image_embeds,
                                    encoder_attention_mask=image_atts,      
                                    return_dict=True,
                                    mode='fusion',
                                   )            
with torch.no_grad():
    bs = image.size(0)  # 获取批量大小          
    weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1)  # 对image到text的相似度进行softmax,沿着第二个维度计算
    weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1)  # 对text到image的相似度进行softmax,沿着第二个维度计算

    weights_i2t.fill_diagonal_(0)  # 将权重矩阵的对角线设为0
    weights_t2i.fill_diagonal_(0)  # 将权重矩阵的对角线设为0

# select a negative image for each text
# 为每个文本选择一个负面的图像
image_embeds_neg = []    
for b in range(bs):
    neg_idx = torch.multinomial(weights_t2i[b], 1).item()  # 根据权重选择负面图像的索引
    image_embeds_neg.append(image_embeds[neg_idx])  # 添加负面图像到列表
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)  # 将负面图像张量堆叠起来

# select a negative text for each image
# 为每张图像选择一个负面的文本
text_embeds_neg = []
text_atts_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_i2t[b], 1).item()  # 根据权重选择负面文本的索引
    text_embeds_neg.append(text_embeds[neg_idx])  # 添加负面文本到列表
    text_atts_neg.append(text.attention_mask[neg_idx])  # 添加负面文本的注意力掩码到列表
text_embeds_neg = torch.stack(text_embeds_neg, dim=0)  # 将负面文本张量堆叠起来
text_atts_neg = torch.stack(text_atts_neg, dim=0)  # 将负面文本的注意力掩码张量堆叠起来

text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)  # 拼接所有的文本张量
text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)  # 拼接所有的文本的注意力掩码张量

image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)  # 拼接所有的图像张量
image_atts_all = torch.cat([image_atts, image_atts], dim=0)  # 拼接所有的图像的注意力掩码张量

output_neg = self.text_encoder.bert(encoder_embeds=text_embeds_all, 
                                    attention_mask=text_atts_all,
                                    encoder_hidden_states=image_embeds_all,
                                    encoder_attention_mask=image_atts_all,      
                                    return_dict=True,
                                    mode='fusion',
                                   )                         

vl_embeddings = torch.cat([output_pos.last_hidden_state[:, 0, :], output_neg.last_hidden_state[:, 0, :]], dim=0)  # 拼接正负样本的嵌入表示
vl_output = self.itm_head(vl_embeddings)  # 输入到信息论训练头部            

itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],  # 创建信息论训练标签
                       dim=0).to(image.device)  # 将标签转移到相同的设备上
loss_itm = F.cross_entropy(vl_output, itm_labels)  # 计算信息论训练损失     

参考:多模态text-image模型之ITM loss-CSDN博客

求Loss的代码:

loss_itm = F.cross_entropy(vl_output, itm_labels)

 

  1. vl_output 是模型输出的分类得分,itm_labels 是每个样本的真实标签。

  2. vl_output:模型输出的是经过训练头部(self.itm_head)的得分,这个头部是一个全连接层,用于将模型学到的特征映射到正面和负面类别的得分。

  3. itm_labels:模型对应的标签,包含了每个样本的真实标签。torch.ones(bs, dtype=torch.long) 是正面样本的标签,设为 1,torch.zeros(2 * bs, dtype=torch.long) 是负面样本的标签,设为 0。然后,使用 torch.cat 函数将这些标签连接起来,形成一个完整的标签张量。

  4. loss_itm:通过调用 F.cross_entropy 函数计算模型输出和真实标签之间的交叉熵损失。这个损失反映了模型预测和实际标签之间的差异,用于指导模型参数的更新,以便更好地区分正面和负面样本。

  • 7
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值