引言:很久之前读blip2,对ITC和ITM大致有个印象,一个对比学习,一个图文匹配的二分类,咋一听好像没什么难理解的,最近好好看了一下源码,觉得实现上很巧妙,值得与诸君共享
这里小编没有一句一句分析,直接源码+注释,觉得这样看比较方便,因为只分析ITC和ITM,所以这里只放了blip2里面的Blip2Qformer的forward函数内容,如有出入,还请各位小伙伴留言斧正!
Image-text Contrastive
###============== Image-text Contrastive ===================###
"""
因为在多张卡上训练,所以这里需要将所有卡上的图像特征收集起来,维度为[batch_size*num_gpu, num_query_tokens, embed_dim],
其中,num_query_tokens是视觉tokens数量,embed_dim是维度
"""
image_feats_all = concat_all_gather(
image_feats
) # [batch_size*num_gpu, num_query_tokens, embed_dim]
# 文本这一步操作与上述同理
text_feat_all = concat_all_gather(text_feat) # [batch_size*num_gpu, embed_dim]
"""
求图像与所有文本的相似度
这里image_feats.unsqueeze(1)之后的维度是[batch_size,1, num_query_tokens, embed_dim]
text_feat_all.unsqueeze(-1)之后的维度是[batch_size*num_gpu, embed_dim,1]
为了求每个图像跟所有文本的相似度,图像特征[batch_size,1, num_query_tokens, embed_dim]第2个维度会被广播到batch_size*num_gpu变成[batch_size*,batch_size*num_gpu, num_query_tokens, embed_dim]
然后矩阵乘法会沿着image_feats和text_feat_all最后两个维度进行相乘,embed_dim维度相乘消失,所以得到的结果为[batch_size,batch_size*num_gpu, num_query_tokens,1]
相乘之后的结果再squeeze()就得到了[batch_size,batch_size*num_gpu, num_query_tokens]
"""
sim_q2t = torch.matmul(
image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
).squeeze() # [batch_size, batch_size*num_gpu, num_query_tokens]
"""
max(-1)表示在最后一个维度上,寻找最大值
也就是说,对每个图像到文本的相似度,选取所有num_query_tokens中的最大值,sim_i2t最终的维度为[batch_size, batch_size*num_gpu]
"""
sim_i2t, _ = sim_q2t.max(-1)
# 通过温度参数self.temp进行相似度的缩放控制
sim_i2t = sim_i2t / self.temp
"""
求文本与所有图像的相似度
text_feat.unsqueeze(1).unsqueeze(1)之后的维度为[batch_size,1,1,embed_dim]
image_feats_all.permute(0, 2, 1)交换后面两个维度之后的特征维度为[batch_size*num_gpu, embed_dim, num_query_token]
同理,文本特征[batch_size,1,1,embed_dim]会广播第2个维度到batch_size*num_gpu,变成[batch_size,batch_size*num_gpu,1,embed_dim]
然后最后两个维度做矩阵乘法得到[batch_size,batch_size*num_gpu,1,num_query_token]
squeeze()之后的特征为[batch_size,batch_size*num_gpu,num_query_token]
"""
sim_t2q = torch.matmul(
text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
).squeeze()
# 对每个文本到图像的相似度,选取所有num_query_tokens中的最大值,sim_i2t最终的维度为[batch_size, batch_size*num_gpu]
sim_t2i, _ = sim_t2q.max(-1)
sim_t2i = sim_t2i / self.temp
rank = dist.get_rank()
bs = image.size(0)
"""
torch.linspace(start, end, steps, dtype=int)的作用是生成从 start 到 end 之间的 steps 个数值,并返回一个 1D 张量
这里用来生成多 GPU 训练中的标签(targets)索引,targets维度维[batch_size]
每个 GPU 进程(或 rank)负责处理自己的 batch,并为它分配唯一的索引序列
"""
targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
image.device
)
if "image_id" in samples.keys(): # coco retrieval finetuning
# 对于包含图像 ID 的样本,使用基于相似度的目标分布计算损失
image_ids = samples["image_id"].view(-1, 1)
image_ids_all = concat_all_gather(image_ids)
pos_idx = torch.eq(image_ids, image_ids_all.t()).float()
sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
sim_targets = 0.9 * sim_targets + 0.1 * torch.ones_like(sim_targets) / sim_targets.size(1)
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
loss_itc = (loss_t2i + loss_i2t) / 2
else:
"""
否则,使用交叉熵计算损失
sim_i2t维度为[batch_size, batch_size*num_gpu],targets维度维[batch_size]
对于sim_i2t每个batch,targets都有唯一一个在 0 到 batch_size * num_gpu - 1 之间真实值,因此可以计算交叉熵,从而达到让正例更接近,负例更远的效果
"""
loss_itc = (
F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
+ F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
) / 2
Image-text Matching
###============== Image-text Matching ===================###
# 同上述
text_input_ids_world = concat_all_gather(text_tokens.input_ids)
text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
image_embeds_world = all_gather_with_grad(image_embeds)
with torch.no_grad():
# 当有image_id时,作者把相似度矩阵里面image_ids相匹配的都mask掉了,即在后面计算的时候忽略样本自身的匹配
if "image_id" in samples.keys():
mask = torch.eq(image_ids, image_ids_all.t())
sim_t2i.masked_fill_(mask, -10000)
sim_i2t.masked_fill_(mask, -10000)
else:
# 与上面同理,将当前 GPU 进程处理的样本的索引范围填充为 -10000,即在后面计算的时候忽略样本自身的匹配
sim_t2i[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)
sim_i2t[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)
# 被masked的值和被fill_diagonal_(-10000),经过softmax之后都会接近于0
weights_t2i = F.softmax(sim_t2i, dim=1)
weights_i2t = F.softmax(sim_i2t, dim=1)
# 为每个文本选择一个负样本图像
image_embeds_neg = []
for b in range(bs):
"""
对每个batch的数据随机选择一个负样本
torch.multinomial从给定的概率分布中进行多项式分布抽样
weights_t2i[b]中值大的数,被采样的概率就大,上述对sim_t2i自身样本进行mask就是为了这里自身样本作为正样本不会被选择
"""
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
image_embeds_neg.append(image_embeds_world[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
# 为每个图像选择一个负样本文本
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_ids_neg.append(text_input_ids_world[neg_idx])
text_atts_neg.append(text_attention_mask_world[neg_idx])
text_ids_neg = torch.stack(text_ids_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)
"""
这一步很妙!
将文本的两个正样本一个负样本进行拼接,为后续二分类做准备
至于为什么这么拼接,后面你就知道了
"""
text_ids_all = torch.cat(
[text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
) # pos, pos, neg
text_atts_all = torch.cat(
[text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
dim=0,
)
# 这一步是对query_tokens进行一些处理
query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
image.device
)
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
"""
将图像的两个正样本一个负样本进行拼接,为后续二分类做准备
注意:文本拼接的顺序是:正样本,正样本,负样本
图像拼接的顺序是:正样本,负样本,正样本
它们只有第一个位置都是正样本,也即第一个位置是一对匹配的正例,后面两个位置都是一正一负是不匹配的,这样我们就可以通过判断它们匹不匹配来进行二分类学习,妙哉!
"""
image_embeds_all = torch.cat(
[image_embeds, image_embeds_neg, image_embeds], dim=0
) # pos, neg, pos
image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
image.device
)
# 将拼接后的文本特征,图像特征以及相应的query_tokens输入到bert中进行分类预测
output_itm = self.Qformer.bert(
text_ids_all,
query_embeds=query_tokens_itm,
attention_mask=attention_mask_all,
encoder_hidden_states=image_embeds_all,
encoder_attention_mask=image_atts_all,
return_dict=True,
)
# 取分类预测的结果
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
vl_output = self.itm_head(vl_embeddings)
logits = vl_output.mean(dim=1)
# 生成对应的真实标签,只有第一个batch文本对是匹配的,所以第一个batch的标签设置为1,其他都是0
itm_labels = torch.cat(
[torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
dim=0,
).to(image.device)
# 将预测结果和真实值输入到交叉熵损失函数中,进行二分类损失计算
loss_itm = F.cross_entropy(logits, itm_labels)