目录
(1)图像-文本对比损失(Image-Text Contrastive Loss, ITC)
(2)图像-文本匹配损失(Image-Text Matching Loss, ITM)
(3)语言建模损失(Language Modeling Loss, LM)
(1)Captioning and Filtering (CapFilt)
(3)视觉问答(Visual Question Answering, VQA)
(4)多模态特征提取(Feature Extraction)
BLIP使用一个模型实现了自然语言理解任务和自然语言生成任务。
1、Motivation
1、从模型角度来看,大多数方法要么采用基于编码器的模型,要么采用编码器-解码器模型。编码器的模型不太容易直接迁移到文本生成的任务中,如图像标题(image captioning)等;编码器—解码器模型还没有被成功用于图像-文本检索任务。
2、从数据角度来看,大多数sota的方法,如CLIP都是对从网上收集的图像—文本对(image-text pair)进行预训练。尽管可以通过扩大数据集的规模来获得性能上的提高,但研究结果显示,有噪声的网络文本对于视觉语言学习来说只能得到次优的结果。
2、网络结构
为了预训练一个具有理解和生成能力的统一模型,论文提出了多模态混合编码器-解码器(Multimodal mixture of Encoder-Decoder,MED),通过一个“模型”处理多个子任务。
1、单模态编码器(Unimodal encoder),对图像和文本分别进行编码。文本编码器(text encoder)与BERT相同,在文本输入的开头附加一个[CLS]标记,以总结句子。图像编码器直接使用ViT,同样使用附加的 [CLS] 标记来表示全局图像特征。
2、基于图像的文本编码器(Image-grounded text encoder),通过在自注意力(SA)层和前馈网络(FFN)之间为文本编码器的每个Transformer块插入一个额外的交叉注意力(CA)层来注入视觉信息。一个特定任务的[Encode]标记被附加到文本上,[Encode]的输出embedding被用作图像-文本对的多模态表示。
3、基于图像的文本解码器(Image-grounded text decoder),使用因果自注意力层(causal self-attention layer)替代编码器中的双向自注意力层。用[Decode]标记来表示一个序列的开始,end-of-sequence token用于标记序列的结束。
3、损失函数
(1)图像-文本对比损失(Image-Text Contrastive Loss, ITC)
集成 ALBEF 中的 ITC损失。(输出图文的相似度)优化vision-transormer + text-transormer,让匹配的图文对有较高相似度的表达(用了soft labels),多模态中的经典loss->使其互信息最大化;
ITC Loss 的全称是 Image-Text Contrastive Loss,为了在融合之前学习更好的unimodal表示,它学习,这里的
和
函数是给cls token embedding降维的线性层。另一方面,文图对会进入一个momentum unimodal encoders(这个结构的作用是通过结合过去更新中积累的知识,帮助稳定和提高学习表示的质量),变成
和
。
计算图像与文本的特征相似度:
对于每个图像和文本,计算softmax归一化的图像到文本和文本到图像相似度为:
其中是可学习的参数。令onehot相似度的真实值是
和
,真值中负样本对的概率为0,正样本对的概率为1,ITC loss为p和y的交叉熵:
代码实现:
with torch.no_grad():
self.temp.clamp_(0.001,0.5)
image_embeds = self.visual_encoder(image) # [1, 3, 224, 224] → [1, 197, 768]
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) # [1, 197],v_cls
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) # [1, 768] → [1, 256],gv
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
return_tensors="pt").to(image.device) # [cls] + caption + [End] + [0, 0, ……]
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text') # [1, 30, 768]
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) # [1, 768] → [1, 256],gw(w_cls)
# get momentum features
with torch.no_grad():
self._momentum_update()
image_embeds_m = self.visual_encoder_m(image) # [1, 197, 768]
image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) # [1, 256]
image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) # [256, 1] + [256, 57600] = [256, 57601],g'v(v'_cls)
text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text') # [1, 30, 768]
text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) # [1, 256]
text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) # [256, 1] + [256, 57600] = [256, 57601],g'w(w'_cls)
sim_i2t_m = image_feat_m @ text_feat_all / self.temp # [1, 57601]
sim_t2i_m = text_feat_m @ image_feat_all / self.temp # [1, 57601]
sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) # [1, 57601]
sim_targets.fill_diagonal_(1) # 对角线1,[1, 57601]
sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets # [1, 57601]
sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets # [1, 57601]
# image_feat和text_feat分别是图片和文本特征,text_feat_all和image_feat_all是从momentum encoder中取出来的文本、图像特征
sim_i2t = image_feat @ text_feat_all / self.temp # [1, 57601]
sim_t2i = text_feat @ image_feat_all / self.temp # [1, 57601]
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
loss_ita = (loss_i2t+loss_t2i)/2
self._dequeue_and_enqueue(image_feat_m, text_feat_m)
(2)图像-文本匹配损失(Image-Text Matching Loss, ITM)
目的是学习图像-文本的多模态表示以捕捉视觉和语言之间的细粒度对齐。(输出图文是否匹配的True/False)优化Image-grounded text encoder,学习图文的细粒度匹配的二分类,采用了hard negative mining strategy(将ITC任务中容易判断错的样本当作hard negative sample);
ITM是一个二元分类任务,模型根据多模态特征使用一个ITM头(线性层)来预测一个图像-文本对是positive(匹配的)还是negative(不匹配的)。
代码实现:
###============== Image-text Matching ===================###
pdb.set_trace()
encoder_input_ids = text.input_ids.clone() # (30)
encoder_input_ids[:,0] = self.tokenizer.enc_token_id # [cls] = [ENC]
# forward the positve image-text pair
bs = image.size(0) # 1
output_pos = self.text_encoder(encoder_input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
) # output_pos.last_hidden_state.shape = [1, 30, 768]
with torch.no_grad():
weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
weights_t2i.fill_diagonal_(0)
weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
weights_i2t.fill_diagonal_(0)
# select a negative image for each text
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item() # 从 weights_t2i[b]这个概率分布中进行采样,得到一个索引,用于选择负样本图片的特征向量
image_embeds_neg.append(image_embeds[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
# select a negative text for each image
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_ids_neg.append(encoder_input_ids[neg_idx])
text_atts_neg.append(text.attention_mask[neg_idx])
text_ids_neg = torch.stack(text_ids_neg,dim=0)
text_atts_neg = torch.stack(text_atts_neg,dim=0)
text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
image_atts_all = torch.cat([image_atts,image_atts],dim=0)
output_neg = self.text_encoder(text_ids_all,
attention_mask = text_atts_all,
encoder_hidden_states = image_embeds_all,
encoder_attention_mask = image_atts_all,
return_dict = True,
)
vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) # [2, 768]
vl_output = self.itm_head(vl_embeddings) # [2, 2]
itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
dim=0).to(image.device) # 生成对应的标签,1 表示正样本,0 表示负样本。
loss_itm = F.cross_entropy(vl_output, itm_labels)
(3)语言建模损失(Language Modeling Loss, LM)
(生成图像的文本描述)优化image-grounded text decoder,学习如何从给定图片生成连贯的文本描述,采用交叉熵损失以自回归的方式最大化对应文本概率。解码器旨在生成给定图像的文本描述,优化了网络输出与给定文本labels的距离。因果自注意力mask掉句子后半部分,然后用前半部分去预测句子后面内容。
decoder_input_ids = text.input_ids.clone()
decoder_input_ids[:,0] = self.tokenizer.bos_token_id
decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
decoder_output = self.text_decoder(decoder_input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
labels = decoder_targets,
return_dict = True,
)
loss_lm = decoder_output.loss
return loss_ita, loss_itm, loss_lm
sequence_output = outputs[0] # [1, 30, 768]
prediction_scores = self.cls(sequence_output) # [1, 30, 30524]
if return_logits:
return prediction_scores[:, :-1, :].contiguous()
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() # [1, 29, 30524]
labels = labels[:, 1:].contiguous() # [1, 29]
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if reduction=='none':
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
ITM和ITC的区别:
- ITM只推开正负样本;ITC既推开正负样本,又拉进正样本。原文是说关注的粒度不同;
- 感觉另一个方面是需要有一个桥接的过程,因为text self-attention不能看image信息,所以ITC不能引入cross-attention;但这样cross-attention就缺失了正负样本对比学习的过程,ITM相当于一个不完全直接对比的手段去实现这个过程。
4、下游任务
(1)Captioning and Filtering (CapFilt)
motivation:
- 有图像——文本对(Ih,Th)类似 COCO 的标注数据集较少
- CLIP中使用的网络收集的图片文本对通常不能准确地描述图像的视觉内容,这使得它们成为一个嘈杂的信号,对于学习视觉语言对齐来说是次优的。
Method:
- Captioner为网络图像生成描述文本,Filter用来过滤不匹配的文本图像对。
- Captioner和Filter都是从同一个预训练的MED模型初始化的,在COCO数据集上单独进行微调。
- Captioner以LM为目标进行微调,使用text decoder对给定的图像解码生成caption。
- Filter以ITC和ITM的目标进行微调,以学习文本是否与图像匹配,该Filter去除原始网络文本和合成文本中的噪音文本,如果ITM头预测一个文本与图像不匹配,则该文本被认为是噪音。
- 过滤后的图像——文本对与人类注释的对相结合,形成一个新的数据集,用于预训练新的模型。
(2)多模态特征提取(Feature Extraction)
image_size = [1, 3, 224, 224]
text = 'a woman sitting on the beach with a dog'
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
if mode=='image':
pdb.set_trace()
# return image features
image_embeds = self.visual_encoder(image) # [1, 197, 768]
return image_embeds
elif mode=='text':
# return text features
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
return text_output.last_hidden_state # [1, 11, 768]
elif mode=='multimodal':
# return multimodel features
image_embeds = self.visual_encoder(image) # [1, 197, 768]
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) # [1, 197]
text.input_ids[:,0] = self.tokenizer.enc_token_id # [Encoder] + text[10]
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
return output.last_hidden_state # [1, 11, 768]
(3)图文匹配(Image-Text Matching)
image_size = [1, 3, 384, 384]
text = 'a woman sitting on the beach with a dog'
itm_output = model(image, caption,match_head='itm')
itm_score = torch.nn.functional.softmax(itm_output,dim=1)[:,1]
print('The image and text is matched with a probability of %.4f'%itm_score)
itc_score = model(image,caption,match_head='itc')
print('The image feature and text feature has a cosine similarity of %.4f'%itc_score)
image_embeds = self.visual_encoder(image) # [1, 577, 768]
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) # [1, 577]
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
return_tensors="pt").to(image.device) # [CLS] + tokens + [end] + [0……0], [1, 35]
if match_head=='itm': # cross-attention(ITM loss)
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
) # output.last_hidden_state.shape = [1, 35, 768]
itm_output = self.itm_head(output.last_hidden_state[:,0,:]) # 取[cls]做全局特征
return itm_output # [1, 2],match / not match
elif match_head=='itc': # self-attention(ITC loss)
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text') # [1, 35, 768]
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) # [1, 256]
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) # [1, 256]
sim = image_feat @ text_feat.t() # (1)
return sim
(4)图像描述(Image Captioning)
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
image_embeds = self.visual_encoder(image) # [1, 3, 384, 384] → [1, 577, 768]
if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) # [3, 577, 768]
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) # [3, 577],生成attention mask
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
prompt = [self.prompt] * image.size(0) # ['a picture of '] * b
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) # [cls] + tokens + [end], [1, 5]
input_ids[:,0] = self.tokenizer.bos_token_id # [DEC] + tokens + [end], [1, 5]
input_ids = input_ids[:, :-1] # # [DEC] + tokens, [1, 4]
if sample:
#nucleus sampling
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
do_sample=True,
top_p=top_p,
num_return_sequences=1,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=1.1,
**model_kwargs)
else:
#beam search
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=repetition_penalty,
**model_kwargs) # [dec] + tokens + [end], [1, 6]
captions = []
for output in outputs:
caption = self.tokenizer.decode(output, skip_special_tokens=True)
captions.append(caption[len(self.prompt):]) # 去掉'a picture of'
return captions
(5)图像问答(Visual Question Answering, VQA)
image_size = [3, 480, 480]
question = "Who is in the picture?"
image_embeds = self.visual_encoder(image) # [1, 901, 768]
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) # [1, 901]
question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
return_tensors="pt").to(image.device) # [cls] + tokens + [end], [1, 8]
question.input_ids[:,0] = self.tokenizer.enc_token_id # [dec] + tokens + [end], [1, 8]
if train:
'''
n: number of answers for each question
weights: weight for each answer
'''
answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
answer.input_ids[:,0] = self.tokenizer.bos_token_id
answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) # 计算损失时忽略padding
question_output = self.text_encoder(question.input_ids,
attention_mask = question.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True)
question_states = [] # hidden state
question_atts = [] # attention mask
for b, n in enumerate(n):
question_states += [question_output.last_hidden_state[b]]*n
question_atts += [question.attention_mask[b]]*n
question_states = torch.stack(question_states,0)
question_atts = torch.stack(question_atts,0)
answer_output = self.text_decoder(answer.input_ids,
attention_mask = answer.attention_mask,
encoder_hidden_states = question_states,
encoder_attention_mask = question_atts,
labels = answer_targets,
return_dict = True,
reduction = 'none',
)
loss = weights * answer_output.loss
loss = loss.sum()/image.size(0)
return loss
else:
question_output = self.text_encoder(question.input_ids,
attention_mask = question.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True) # question_output.last_hidden_state = [1, 8, 768]
if inference=='generate':
num_beams = 3
question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) # [3, 8, 768],产生多个可能的结果
question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) # [3, 8]
model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) # [dec]
outputs = self.text_decoder.generate(input_ids=bos_ids,
max_length=10,
min_length=1,
num_beams=num_beams,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
**model_kwargs) # [cls] + tokens + [end]
answers = []
for output in outputs:
answer = self.tokenizer.decode(output, skip_special_tokens=True)
answers.append(answer)
return answers