主要代码段解析:
# forward the positve image-text pair
# 正向传播正面的图像文本对
output_pos = self.text_encoder.bert(encoder_embeds=text_embeds,
attention_mask=text.attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
mode='fusion',
)
with torch.no_grad():
bs = image.size(0) # 获取批量大小
weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1) # 对image到text的相似度进行softmax,沿着第二个维度计算
weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1) # 对text到image的相似度进行softmax,沿着第二个维度计算
weights_i2t.fill_diagonal_(0) # 将权重矩阵的对角线设为0
weights_t2i.fill_diagonal_(0) # 将权重矩阵的对角线设为0
# select a negative image for each text
# 为每个文本选择一个负面的图像
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item() # 根据权重选择负面图像的索引
image_embeds_neg.append(image_embeds[neg_idx]) # 添加负面图像到列表
image_embeds_neg = torch.stack(image_embeds_neg, dim=0) # 将负面图像张量堆叠起来
# select a negative text for each image
# 为每张图像选择一个负面的文本
text_embeds_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item() # 根据权重选择负面文本的索引
text_embeds_neg.append(text_embeds[neg_idx]) # 添加负面文本到列表
text_atts_neg.append(text.attention_mask[neg_idx]) # 添加负面文本的注意力掩码到列表
text_embeds_neg = torch.stack(text_embeds_neg, dim=0) # 将负面文本张量堆叠起来
text_atts_neg = torch.stack(text_atts_neg, dim=0) # 将负面文本的注意力掩码张量堆叠起来
text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) # 拼接所有的文本张量
text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0) # 拼接所有的文本的注意力掩码张量
image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) # 拼接所有的图像张量
image_atts_all = torch.cat([image_atts, image_atts], dim=0) # 拼接所有的图像的注意力掩码张量
output_neg = self.text_encoder.bert(encoder_embeds=text_embeds_all,
attention_mask=text_atts_all,
encoder_hidden_states=image_embeds_all,
encoder_attention_mask=image_atts_all,
return_dict=True,
mode='fusion',
)
vl_embeddings = torch.cat([output_pos.last_hidden_state[:, 0, :], output_neg.last_hidden_state[:, 0, :]], dim=0) # 拼接正负样本的嵌入表示
vl_output = self.itm_head(vl_embeddings) # 输入到信息论训练头部
itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], # 创建信息论训练标签
dim=0).to(image.device) # 将标签转移到相同的设备上
loss_itm = F.cross_entropy(vl_output, itm_labels) # 计算信息论训练损失