理论核心
- 冻结视觉编码器,训练Q-former,学习视觉-文本的表征
- 冻结LLM,训练视觉到文本的生成
模型架构
Q-Former consists of two transformer submodules that share the same self-attention layers
- 带有视觉编码器的image Transformer,来提取视觉特征
- a text transformer 同时充当 a text encoder 和 text decoder
- 可学习query embedding
- 首先和shared self- attention交互,可以学习到text信息
- 然后通过cross attention与image encoder交互,学习图像的信息
Image Transformer输出的是学习到的query embedding 也就是 Z Z Z,Text Transformer学习到的是文本表征,也就是 t t t,就剩【CLS】token的embedding
Since Z contains multiple output embeddings (one from each query), we first compute the
pairwise similarity between each query output and t, and then select the highest one as the image-text similarity
query的具体组成:
好好好,居然在撩我!!!
Q-Former核心
Image-text Contrastive
# 将image_feats扩展到所有GPU上,image_feats是图像特征,是一个tensor,维度为[batch_size*num_gpu, num_query_tokens, embed_dim]
image_feats_all = concat_all_gather(image_feats)
# 将text_feat扩展到所有GPU上,text_feat是文本特征,是一个tensor,维度为[batch_size*num_gpu, embed_dim]
text_feat_all = concat_all_gather(text_feat)
# 对每个查询标记计算图像到文本的相似度
sim_q2t = torch.matmul(
image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
).squeeze()
# 对image-text相似度进行聚合,取所有查询标记的最大值
sim_i2t, _ = sim_q2t.max(-1)
sim_i2t = sim_i2t / self.temp # 对相似度进行缩放
# 对每个查询标记计算文本到图像的相似度
sim_t2q = torch.matmul(
text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
).squeeze()
# 对text-image相似度进行聚合,取所有查询标记的最大值
sim_t2i, _ = sim_t2q.max(-1)
sim_t2i = sim_t2i / self.temp # 对相似度进行缩放
# 获取进程的rank和batch的大小
rank = dist.get_rank()
bs = image.size(0)
# 生成targets,用于计算损失
targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
image.device
)
# 如果samples中包含image_id,则表示在COCO检索微调训练中
if "image_id" in samples.keys():
# 获取image_ids
image_ids = samples["image_id"].view(-1,1)
# 将所有图片的image_id扩展到所有GPU上
image_ids_all = concat_all_gather(image_ids)
# 计算相似度目标,对匹配的图像进行惩罚
pos_idx = torch.eq(image_ids, image_ids_all.t()).float()
sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
sim_targets = 0.9 * sim_targets + 0.1 * torch.ones_like(sim_targets) / sim_targets.size(1)
# 计算损失,在COCO检索微调训练中
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_targets,dim=1).mean()
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_targets,dim=1).mean()
loss_itc = (loss_t2i+loss_i2t)/2
# 如果不是COCO检索微调训练
else:
# 普通的图像-文本对比损失计算
loss_itc = (
F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
+ F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
) / 2
Text-Image Match
# 将text_tokens中的输入ID和注意力掩码扩展到所有GPU上
text_input_ids_world = concat_all_gather(text_tokens.input_ids)
text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
# 将image_embeds扩展到所有GPU上,并且允许梯度传播
image_embeds_world = all_gather_with_grad(image_embeds)
# 禁止梯度传播,用于下面的操作
with torch.no_grad():
# 如果samples中包含image_id,则在image_ids与image_ids_all.t()相等的位置填充-10000
if "image_id" in samples.keys():
mask = torch.eq(image_ids, image_ids_all.t())
sim_t2i.masked_fill_(mask, -10000)
sim_i2t.masked_fill_(mask, -10000)
else:
# 否则在sim_t2i和sim_i2t的对角线位置填充-10000
sim_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)
sim_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)
# 计算softmax权重,对文本到图像和图像到文本的相似性进行归一化
weights_t2i = F.softmax(sim_t2i, dim=1)
weights_i2t = F.softmax(sim_i2t, dim=1)
# 为每个文本选择一个负向图像
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
image_embeds_neg.append(image_embeds_world[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
# 为每个图像选择一个负向文本
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_ids_neg.append(text_input_ids_world[neg_idx])
text_atts_neg.append(text_attention_mask_world[neg_idx])
text_ids_neg = torch.stack(text_ids_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)
# 将所有的正向文本和负向文本组合成一个tensor
text_ids_all = torch.cat(
[text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
) # pos, pos, neg
text_atts_all = torch.cat(
[text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
dim=0,
)
# 生成用于查询任务的token和注意力掩码
query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
image.device
)
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
# 将所有的正向图像和负向图像组合成一个tensor
image_embeds_all = torch.cat(
[image_embeds, image_embeds_neg, image_embeds], dim=0
) # pos, neg, pos
image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
image.device
)
# 执行查询任务,获取输出
output_itm = self.Qformer.bert(
text_ids_all,
query_embeds=query_tokens_itm,
attention_mask=attention_mask_all,
encoder_hidden_states=image_embeds_all,
encoder_attention_mask=image_atts_all,
return_dict=True,
)
# 提取视觉-语言交互层的输出
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
# 对视觉-语言交互层的输出应用itm头
vl_output = self.itm_head(vl_embeddings)
logits = vl_output.mean(dim=1)
# 生成itm标签,对正向样本标记为1,对负向样本标记为0
itm_labels = torch.cat(
[torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
dim=0,
).to(image.device)
# 计算itm任务的交叉熵损失
loss_itm = F.cross_entropy(logits, itm_labels)
Image Caption
# 克隆text_tokens的输入ID,并将第一个位置设置为起始符号的ID
decoder_input_ids = text_tokens.input_ids.clone()
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
# 将要预测的标签设置为decoder_input_ids,并使用-100填充pad位置
labels = decoder_input_ids.masked_fill(
decoder_input_ids == self.tokenizer.pad_token_id, -100
)
# 生成注意力掩码,用于语言模型的输入
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
image.device
)
attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
# 执行语言模型的训练,获取LM输出
lm_output = self.Qformer(
decoder_input_ids,
attention_mask=attention_mask,
past_key_values=query_output.past_key_values,
return_dict=True,
labels=labels,
)
# 计算语言模型的损失
loss_lm = lm_output.loss
# 返回总的损失和各个部分的损失
return BlipOutput(
loss=loss_itc + loss_itm + loss_lm,
loss_itc=loss_itc,
loss_itm=loss_itm,
loss_lm=loss_lm,
)
Generate
def generate(
self,
samples,
use_nucleus_sampling=False,
num_beams=3,
max_length=30,
min_length=10,
top_p=0.9,
repetition_penalty=1.0,
):
"""
Args:
samples (dict): 包含以下键的字典:
- image (torch.Tensor): 形状为(batch_size, 3, H, W)的张量
use_nucleus_sampling (bool): 是否使用核采样。如果为False,则使用top-k采样。
num_beams (int): 用于束搜索的束的数量。1表示不使用束搜索。
max_length (int): 要生成的序列的最大长度。
min_length (int): 要生成的序列的最小长度。
top_p (float): 核采样的累积概率。
repetition_penalty (float): 重复惩罚的参数。1.0表示没有惩罚。
num_captions (int): 每个图像要生成的字幕数。
Returns:
captions (list): 长度为batch_size * num_captions的字符串列表。
"""
# 获取图像并将其编码为图像嵌入
image = samples["image"]
image_embeds = self.ln_vision(self.visual_encoder(image))
# 如果不使用核采样,则扩展图像嵌入,用于束搜索
if not use_nucleus_sampling:
image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
else:
num_beams = 1
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
# 设置模型参数
model_kwargs = {
"encoder_hidden_states": image_embeds,
"encoder_attention_mask": image_atts,
}
# 生成文本描述
input_ids = (
torch.LongTensor(image.size(0), 1)
.fill_(self.tokenizer.bos_token_id)
.to(image.device)
)
# image token编码
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
# 调用Qformer进行生成
outputs = self.Qformer.generate(
input_ids=input_ids,
query_embeds=query_tokens,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
do_sample=use_nucleus_sampling,
top_p=top_p,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
**model_kwargs
)
# 解码生成captions
captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
return captions
qformer的generate模块好复杂
generate 函数用于模型的条件生成。
首先,generate 函数通过各种参数调用生成。这些参数包括生成配置文件、标记处理列表、停止条件列表、前缀允许令牌 fn、委托人模型、流媒体设备、负提示 ids 和注意力掩码等。这些参数用来进一步控制生成的方式和结果。
generate 函数主要分为以下步骤:
-
处理
generation_config
和可能更新它的 kwargs,以及验证.generate()
调用。 -
设置生成参数(如处理?)
-
定义模型输入
-
定义其他模型kwargs
-
准备自回归生成的
input_ids
-
准备包含其他停止标准的
max_length
-
确定生成模式
有了这些步骤,generate
函数可以进入不同的生成模式和执行相应的生成方法,比如贪婪搜索,显示搜索等。
在选择相应的生成模式后,generate
函数根据模型的生成配置、输入和相应的参数调用相应的生成方法,包括模型的贪婪搜索、显示搜索、样本生成等。
接下来,通过选择合适的方法和参数对模型进行生成,并返回生成的输出。
最后,如果选择了与专辑模式,则可以使用 assistant_model 对象来进行助攻生成。在这种情况下,生成方法中的一些参数和逻辑都将有所不同。目的是加速生成。具体来说,在助攻生成过程中,特定的助攻模型将返回模型的生成 output 或torch.FloatTensor
。
总的来说,generate
函数负责执行不同的生成方法和逻辑以生成模型的输出。它允许用户根据实际需要执行不同的生成方法,并支持其他参数的进一步控制。这个函数给了人们灵活的选择,以获得满足需求的生成输出。
BLIP2实战
Image2text
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
device = "cuda:0"
inputs = processor(image, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
text+img2text
- 注意:通过提供文本提示来扩展图像字幕生成,模型将在给定图像的情况下接着提示词往下补充
prompt = "this is a cartoon of"
inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
VQA
- 用于视觉问答时,提示必须遵循特定格式: “Question: {} Answer:”
prompt = "Question: What is a dinosaur holding? Answer:"
inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
contextQA
context = [
("What is a dinosaur holding?", "a torch"),
("Where are they?", "In the woods.")
]
question = "What for?"
template = "Question: {} Answer: {}."
prompt = " ".join([template.format(context[i][0], context[i][1]) for i in range(len(context))]) + " Question: " + question + " Answer:"
print(prompt)
inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
- https://huggingface.co/blog/zh/blip-2#%E5%9B%BE%E5%83%8F%E5%AD%97%E5%B9%95%E7%94%9F%E6%88%90