【多模态】BLIP-2模型技术学习

前言

最近多模态模型特别火,从头开始学习!在前面写的几篇里面学习了MiniCPM-V、ViT、CLIP和BLIP之后,今天学习一下BLIP-2模型,记录学习过程,主要是模型架构、训练方式和相关源代码的理解。欢迎批评指正,一起学习~~

1 充分利用现有视觉和文本模型的多模态模型

  • BLIP2出自Salesforce发表在ICML 2023的论文BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models
  • 通过加入一个Q-Former模块将视觉模块和LLM连接起来,预训练过程中只训练Q-Former,冻结视觉模块和LLM,充分利用现有模型能力【 Q-Former = query + transformer,会训练query和模型本身】
  • 包含2个阶段的预训练:(1) 视觉模块+Q-Former,训练Q-Former从图片中提取出和文本最相关的表征;(2) Q-Former+LLM,训练Q-Former产生出能让LLM看得懂的视觉表征
    在这里插入图片描述

2 Q-Former两阶段的预训练

  两阶段的预训练,这两个阶段是同一个Q-Former,先进行第一阶段训练,再进行第二阶段训练,这两个阶段数据流经过的模型不一样。

第一阶段

  • 第一阶段论文里面的图比较抽象,看着不太容易理解。总而言之,Q-Former接在图像编码器后,Q-Former自身充当文本模型
  • 实际上Q-Former就是一个bert,没有2个encoder;左边部分是输入会经过cross-attention层图文直接交互,右边部分没有cross-attention层
  • 这一阶段有3个训练loss,和BLIP一致,训练query和模型从图片中提取出和文本最相关的表征:
      (1) Image-Text Contrastive Loss (ITC) – 图文对比学习,在隐空间对齐图片编码和文本编码
      (2) Image-Text Matching Loss (ITM) - 二分类任务,让模型判断图文是否一致
      (3) Image-grounded Text Generation (ITG) - 下一词预测,让模型学会给定图片输出caption
    在这里插入图片描述

第二阶段

  • 第二阶段,Q-Former接在图像编码器后,同时Q-Former的query对应的输出接到LLM上
  • 这一个阶段训练Q-Former输出让LLM能看得懂的中间结果,是Language Modeling Loss
    在这里插入图片描述

3 官方代码结构

官方代码在:https://github.com/salesforce/LAVIS/tree/main/lavis,其中和BLIP-2有关的需要关注:

  • 训练脚本:lavis/run_scripts/blip2/train/
  • 训练配置: lavis/projects/blip2/train/
  • 模型配置(第二阶段模型): lavis/configs/models/blip2/,里面有例如blip2_caption_opt6.7b.yaml
  • 模型代码: lavis/models/blip2_models/,关注blip2.py, Q-Former.py, blip2_qformer.py(第1阶段预训练),blip2_opt.py(第2阶段预训练)

训练调用train.py,loss是在每个模型类里面返回的,下载完整项目在pycharm里面看会清晰很多

4 第一阶段预训练loss计算——ITC图文对比学习Loss

4.1 基础数据准备

核心代码在blip2_qfomer.py的forward()函数,同时训练query和bert,这个函数会返回计算的3个loss,首先会获取基础的图片编码和文本编码:

  • 图片编码是普通的ViT的编码结果
  • 这里的文本编码不是LLM的编码,而是Q-Former的bert模块的编码结果,经过bert的Encoder的self-attention层获得文本编码输出,由于是单文本模态,不经过bert的Encoder的交叉注意力层
image = samples["image"]
text = samples["text_input"]

image_embeds = self.ln_vision(self.visual_encoder(image))  # go through ViT
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
text_tokens = self.tokenizer(
    text,
    padding="max_length",
    truncation=True,
    max_length=self.max_txt_len,
    return_tensors="pt",
).to(image.device)
text_output = self.Qformer.bert(
    text_tokens.input_ids,
    attention_mask=text_tokens.attention_mask,
    return_dict=True,
)
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

4.2 query_output生成

接着获取用query查询图片的输出query_output,用于后面优化query和bert,需要注意:

  • query_tokens的默认大小是[1,32,hidden_size],每个图片slice对应长度为num_query_token的query向量
  • expand这行代码很关键,是把query复制了batch份,query变成[batch,32,hidden_size],batch内的每个样本是使用同一份query,每1份是共享内存的,只存在1份query;如果改变query_tokens[0],query_tokens每一个索引下内容都会变
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)  # encoder_width = ViT feature width

# query见到图片,得到一个输出,普通的交叉注意力,query真的是query,目标是优化query
# query_output的形状和query_tokens一样
query_output = self.Qformer.bert(
    query_embeds=query_tokens,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    use_cache=True,
    return_dict=True,
)

# 隐向量维度统一,图片和文本的统一,映射到同一个维度embed_dim
image_feats = F.normalize(self.vision_proj(query_output.last_hidden_state), dim=-1)

上面和一开始文本输入bert不一样,Qformer.bert输出query_output过程中会经过bert的encoder的交叉注意力层:

  • 在Qformer.py的BertLayer类中可以看到,每一个layer都会先经过自注意力模块,如果是第0层,hidden_states就是query的embedding,如果在第i(i>0)层,hidden_states就是中间输出
  • 当有query_length>0时,有图像时,会计算交叉注意力,query_attention_output取query长度对应的注意力输出,encoder_hidden_states是图像编码,图像编码作为自注意力的key和value
self_attention_outputs = self.attention(
    hidden_states, attention_mask, head_mask,
    output_attentions=output_attentions,
    past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]

if query_length > 0:
    query_attention_output = attention_output[:, :query_length, :]
    if self.has_cross_attention:
        cross_attention_outputs = self.crossattention(
            query_attention_output, attention_mask, head_mask,
            encoder_hidden_states, encoder_attention_mask,
            output_attentions=output_attentions,
        )
        query_attention_output = cross_attention_outputs[0]

4.3 ITC图文对比学习

  • 对比学习会计算相似度矩阵,分别计算相似度矩阵按行和按列的相似度,在CLIP和BLIP中都有这样的操作
  • 但是在BLIP-2这里会发现做矩阵乘法张量维度是不一样的
image_feats_all = concat_all_gather(image_feats)  # [batch_size, num_query_tokens, embed_dim]
text_feat_all = concat_all_gather(text_feat)  # [batch_size, embed_dim]

sim_q2t = torch.matmul(
    image_feats.unsqueeze(1),  # [batch_size, 1, num_query_tokens, embed_dim]
    text_feat_all.unsqueeze(-1)  # [batch_size, embed_dim, 1]
).squeeze()
# 得到 [batch_size, batch_size, num_query_tokens]

# text-query similarity: [batch_size, batch_size, num_query_tokens]
sim_t2q = torch.matmul(
    text_feat.unsqueeze(1).unsqueeze(1),  # [batch_size, 1, 1, embed_dim]
    image_feats_all.permute(0, 2, 1)  # [batch_size, num_query_tokens, embed_dim]
).squeeze()

  具体而言,以query和文本相似度为例,涉及image_feats和text_feat_all的乘法,这两个张量维度不一样。
  Pytorch的matmul运算机制会自动把低维度的做维度扩展,先把右边维度对齐,扩左边,左边最前面添加1维,扩展了之后image_feats_expanded第一维大小为batch_size,所以text_feat_all广播到batch_size同样大小,像下面这样:

# 扩展维度
image_feats_expanded = image_feats.unsqueeze(1)  # 形状为 [batch_size, 1, num_query_tokens, embed_dim]
text_feat_all_expanded = text_feat_all.unsqueeze(-1).unsqueeze(0)  # 形状为 [1, batch_size, embed_dim, 1]

# 广播机制 形状为 [batch_size, batch_size, embed_dim, 1]
text_feat_all_expanded = text_feat_all_expanded.expand(batch_size, -1, -1, -1)

  有了相似度矩阵后,会计算图文/文图相似度,这里取了query对应的相似度max的元素来计算相似度。
  到这里ITC损失计算完毕,因为query和图像一起过的bert,text单独过的bert,它们没有交集,query和text不会互相看见泄露信息,如下面的代码所示:

# image-text similarity: aggregate across all query tokens
sim_i2t, _ = sim_q2t.max(-1)
sim_i2t = sim_i2t / self.temp

# text-image similarity: aggregate across all query tokens
sim_t2i, _ = sim_t2q.max(-1)
sim_t2i = sim_t2i / self.temp  # [batch_size, batch_size*num_gpu]

rank = dist.get_rank()
bs = image.size(0)
targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(image.device)

loss_itc = (
    F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) +
    F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
) / 2

5 第一阶段预训练loss计算——ITM图文匹配Loss

5.1 正负样本采样

BLIP-2的ITM图文匹配损失计算方式和BLIP一致,相似度矩阵对角线填-10000(对角线文本和图片是对应的,是正例),然后依概率抽负样本:

text_input_ids_world = concat_all_gather(text_tokens.input_ids)
text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
image_embeds_world = all_gather_with_grad(image_embeds)

sim_t2i[:, :bs].fill_diagonal_(-10000)
sim_i2t[:, :bs].fill_diagonal_(-10000)

weights_t2i = F.softmax(sim_t2i, dim=1)
weights_i2t = F.softmax(sim_i2t, dim=1)

# select a negative image for each text
image_embeds_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_t2i[b], 1).item()
    image_embeds_neg.append(image_embeds_world[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

# select a negative text for each image
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_i2t[b], 1).item()
    text_ids_neg.append(text_input_ids_world[neg_idx])
    text_atts_neg.append(text_attention_mask_world[neg_idx])
text_ids_neg = torch.stack(text_ids_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)

  同样,和BLIP一致,会把正负样本填到列表里。但是要注意,这里输入的attention_mask_all 长度涵盖了query和text,因为函数里面后续会把query和text拼接,所以会在下面代码里面先看到这个奇怪的拼接操作:

text_ids_all = torch.cat(
    [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
)  # pos, pos, neg

text_atts_all = torch.cat(
    [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
    dim=0,
)

query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(image.device)
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)  # 注意这个mask长度包含了query和text

image_embeds_all = torch.cat(
    [image_embeds, image_embeds_neg, image_embeds], dim=0
)  # pos, neg, pos

image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
    image.device
)

5.2 ITM的loss计算

  有了正负样本,就可以计算ITM的loss,这里面有文本,有query,还有图像编码,取query对应的作为vl_embbeding计算loss:

output_itm = self.Qformer.bert(
    text_ids_all,
    query_embeds=query_tokens_itm,
    attention_mask=attention_mask_all,
    encoder_hidden_states=image_embeds_all,
    encoder_attention_mask=image_atts_all,
    return_dict=True,
)  # [batch_size, 32+seq_len, embed_dim]

vl_embeddings = output_itm.last_hidden_state[:, :query_tokens_itm.size(1), :]
vl_output = self.itm_head(vl_embeddings)
logits = vl_output.mean(dim=1)

itm_labels = torch.cat(
    [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
    dim=0,
).to(image.device)

loss_itm = F.cross_entropy(logits, itm_labels)

5.3 query、text和image输入流程

  query、text和图像编码输入bert的过程有一些复杂,首先是会经过embbeding层:

query_length = query_embeds.shape[1] if query_embeds is not None else 0

# 如果同时有text和query,把query拼接在text前面,并且query的mask是1
embedding_output = self.embeddings(
    input_ids=input_ids,
    position_ids=position_ids,
    query_embeds=query_embeds,
    past_key_values_length=past_key_values_length,
)  # [batch_size, 32+seq_len, embed_dim]

  可以看到在embedding层里面做了拼接,把query拼到text前面,这也是前面要在mask里面拼上query这部分长度的原因:

if input_ids is not None:
    embeddings = self.word_embeddings(input_ids)
    if query_embeds is not None:
        embeddings = torch.cat((query_embeds, embeddings), dim=1)
else:
    embeddings = query_embeds

embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings

  接着,embedding_output与图像编码encoder_hidden_states一起经过bert的encoder和之前一样

6 第一阶段预训练loss计算——ITG文本生成Loss

6.1 基本流程

  • 最后一个是文本生成损失,需要注意mask的形状,以及Qformer的输入
  • 输入了text,以及query_output的KV值(这部分有图像信息),基于图像信息解码,没有输入image_embed【Image-grounde并不是说把图片编码一起输进去】,在后面的代码展开部分会看到一个拼接操作嵌入图像信息
decoder_input_ids = text_tokens.input_ids.clone()
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
labels = decoder_input_ids.masked_fill(
    decoder_input_ids == self.tokenizer.pad_token_id, -100
)

# [batch_size, 32, hidden_size]
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)

# 用于计算prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
# mask这里需要注意,query不需要生成;
attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)

# 因为传递的是query的KV,并且会把query_output的key和value拼接在text的key和value前面,后面会看到这部分
# 相当于后面text已经看到了前面的query,在这个函数里面会把query的KV拼接到text的KV前面
lm_output = self.Qformer(
    decoder_input_ids,
    attention_mask=attention_mask,
    past_key_values=query_output.past_key_values,  # 把query_output看到图像的KV,和decoder这里的KV直接拼接,嵌入图像信息
    return_dict=True,
    labels=labels,
)

loss_lm = lm_output.loss

  掩码生成部分的代码在Qformer的bert的encoder中,Qformer的forward函数的is_decoder=True,会看到掩码前面有一段全1的prefix_seq_len,后半部分是三角形的causal_mask:

if is_decoder:
    batch_size, seq_length = input_shape  # text.input_id.shape
    seq_ids = torch.arange(seq_length, device=device)
    causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None])
    # add a prefix ones mask to the causal mask
    # causal and attention masks must have same type with pytorch version < 1.3
    causal_mask = causal_mask.to(attention_mask.dtype)

    if causal_mask.shape[1] < attention_mask.shape[1]:
        prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]  # 所以前面attention_mask传了32+len
        causal_mask = torch.cat([
            torch.ones((batch_size, causal_mask.shape[1], prefix_seq_len),
                       device=device, dtype=causal_mask.dtype),
            causal_mask
        ], axis=-1)

    extended_attention_mask = (causal_mask[:, None, :, :] * attention_mask[:, None, None, :])

6.2 图像信息拼接

  ITG是Image-grounded的text generation,图片信息不是通过encoder_hidden_stats传递的,而是通过query_output的past_key_value传递的。
  前面说到,对于query_output的每一层的key和value,实际上现在会传递到bert的encoder的layer中,和text的KV拼接,实现解码过程中图片信息的嵌入。

# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = (
    past_key_value[:2] if past_key_value is not None else None
)

self_attention_outputs = self.attention(
    hidden_states,
    attention_mask,
    head_mask,
    output_attentions=output_attentions,
    past_key_value=self_attn_past_key_value,  # query_output的key和value
)

  继续深入,在self-attention层可以看到, query_output的每一层的key和value,拼接在了text的key和value的前面,实现在解码时插入图像相关信息:

elif past_key_value is not None:
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
    key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
    value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

7 第二阶段文本生成Loss计算

  • 前一阶段的文本生成是使用Q-Former的bert模块来训练生成,这一阶段是Q-Former+LLM,在blip2_opt.py
  • 这一阶段是让Q-Former学会生成能让LLM看懂的视觉特征,只用了Q-Former的query提取图像信息这部分
# 扩展query到与图像嵌入相同的批次大小
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

# 使用Qformer模型的BERT部分对查询token进行编码
query_output = self.Qformer.bert(
    query_embeds=query_tokens,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    return_dict=True,
)

# 对查询输出进行投影转换,准备与文本部分结合
inputs_opt = self.opt_proj(query_output.last_hidden_state)  # 视觉部分编码

# 创建注意力掩码,确保所有token都能相互关注
atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device)

# 设置OPT模型的填充侧为右侧
self.opt_tokenizer.padding_side = "right"

# 处理样本中的文本输入,每个文本后添加换行符
text = [t + "\n" for t in samples["text_input"]]

# 对文本进行编码
opt_tokens = self.opt_tokenizer(
    text,
    return_tensors="pt",
    padding="longest",
    truncation=True,
    max_length=self.max_txt_len,
).to(image.device)

# 创建目标张量,将填充token的位置标记为-100,这样在计算交叉熵时不考虑这些位置
targets = opt_tokens.input_ids.masked_fill(
    opt_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100
)

# 如果存在prompt,则不对其应用损失
if self.prompt:  # prompt不需要生成,例如 'a photo of'
    targets[:, :self.prompt_length] = -100  # 不对提示词部分计算损失

  在这一阶段给LLM的输入是视觉和文本编码进行拼接,target是文本caption,所以视觉部分是empty的target,和prompt一样填充-100无需算到cross_entropy里面:

empty_targets = (
    torch.ones(atts_opt.size(), dtype=torch.long).to(image.device).fill_(-100)
)
targets = torch.cat([empty_targets, targets], dim=1)

inputs_embeds = self.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)  # 拼接视觉和文本
attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

with self.maybe_autocast():
    outputs = self.opt_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        return_dict=True,
        labels=targets,
    )
loss = outputs.loss
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值