1. Introduction
Paper: MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models
GitHub: MiniGPT-4
MiniGPT-4 is an efficient multimodal model that aligns vision and language by combining the vision components of BLIP-2 with the Vicuna language model through a single projection layer. It adopts a freezing strategy: only the projection layer is fine-tuned, and a second stage of training on high-quality data further improves its conversational ability. This lets MiniGPT-4 show visual dialogue capabilities close to GPT-4 while using comparatively little compute.
2. Overall Architecture
MiniGPT-4 starts from the observation that new language models and vision encoders are released at a rapid pace. It therefore freezes the ViT & Q-Former and the language model parameters, and trains only the projection layer that connects them. This is very effective in practice and avoids the large amount of compute that retraining the vision side would require.
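To make this training setup concrete, here is a minimal sketch of the freezing strategy, assuming a generic vision encoder, Q-Former, and LLM. The function and argument names are placeholders rather than MiniGPT-4's actual API, and the default dimensions (768 for a BERT-base Q-Former, 4096 for a 7B LLaMA-family model, 5120 for the 13B variant) are only typical values:

```python
import torch.nn as nn

def build_trainable_projection(vision_encoder: nn.Module,
                               qformer: nn.Module,
                               llm: nn.Module,
                               qformer_dim: int = 768,
                               llm_dim: int = 4096) -> nn.Module:
    """Freeze every pretrained component and return the only trainable module."""
    for module in (vision_encoder, qformer, llm):
        for param in module.parameters():
            param.requires_grad = False  # frozen: no gradients, no optimizer state
        module.eval()

    # The single projection layer that maps Q-Former features into the LLM space
    proj = nn.Linear(qformer_dim, llm_dim)
    return proj  # only proj.parameters() goes into the optimizer
```

Because only the projection's parameters are passed to the optimizer, both the optimizer-state memory and the backpropagation cost stay small.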
3. Technical Details
3.1 Model Structure and Initialization: Vision Encoder and Q-Former
During model initialization, the init_vision_encoder and init_Qformer methods build the vision encoder and the Q-Former, respectively.
3.1.1 Vision encoder initialization:
```python
def init_vision_encoder(
    cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
):
    logging.info('Loading VIT')

    assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
    if not freeze:
        precision = "fp32"  # fp16 is not for training

    # Build the EVA-CLIP ViT-G image encoder
    visual_encoder = create_eva_vit_g(
        img_size, drop_path_rate, use_grad_checkpoint, precision
    )

    # LayerNorm over the encoder's output features
    ln_vision = LayerNorm(visual_encoder.num_features)

    if freeze:
        # Freeze the encoder and the LayerNorm and keep them in eval mode
        for name, param in visual_encoder.named_parameters():
            param.requires_grad = False
        visual_encoder = visual_encoder.eval()
        visual_encoder.train = disabled_train  # prevent .train() from re-enabling training behavior
        for name, param in ln_vision.named_parameters():
            param.requires_grad = False
        ln_vision = ln_vision.eval()
        ln_vision.train = disabled_train
        logging.info("freeze vision encoder")

    logging.info('Loading VIT Done')
    return visual_encoder, ln_vision
```
- Vision encoder choice: EVA-CLIP (ViT-G) is used as the image feature extractor, encoding the input image into visual embeddings.
- Layer normalization (LayerNorm): the encoded embeddings are normalized to keep the features stable.
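For intuition about the tensor shapes involved, the sketch below replaces the real EVA-CLIP weights with a dummy tensor. The 257 tokens (16×16 patches plus one CLS token at 224×224 resolution with patch size 14) and the feature width of 1408 are the ViT-G values used by BLIP-2; visual_encoder.num_features would play the role of vit_width here:

```python
import torch
import torch.nn as nn

# Assumed shapes for EVA-CLIP ViT-G/14 at 224x224: 256 patch tokens + 1 CLS, width 1408
batch, num_tokens, vit_width = 2, 257, 1408

# Stand-in for visual_encoder(image); in MiniGPT-4 this comes from create_eva_vit_g
dummy_vit_features = torch.randn(batch, num_tokens, vit_width)

# The LayerNorm built in init_vision_encoder normalizes over the feature dimension
ln_vision = nn.LayerNorm(vit_width)
image_embeds = ln_vision(dummy_vit_features)
print(image_embeds.shape)  # torch.Size([2, 257, 1408])
```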
3.1.2 Q-Former initialization:
```python
def init_Qformer(cls, num_query_token, vision_width, freeze):
    # Start from a BERT-base config and enable cross-attention towards the visual features
    encoder_config = BertConfig.from_pretrained("bert-base-uncased")
    encoder_config.encoder_width = vision_width
    # insert cross-attention layer every other block
    encoder_config.add_cross_attention = True
    encoder_config.cross_attention_freq = 2
    encoder_config.query_length = num_query_token
    Qformer = BertLMHeadModel(config=encoder_config)

    # Learnable query tokens that will attend to the image embeddings
    query_tokens = nn.Parameter(
        torch.zeros(1, num_query_token, encoder_config.hidden_size)
    )
    query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

    # Drop the BERT parts that are not needed for feature extraction
    Qformer.cls = None
    Qformer.bert.embeddings.word_embeddings = None
    Qformer.bert.embeddings.position_embeddings = None
    for layer in Qformer.bert.encoder.layer:
        layer.output = None
        layer.intermediate = None

    if freeze:
        # Freeze the Q-Former and its query tokens
        for name, param in Qformer.named_parameters():
            param.requires_grad = False
        Qformer = Qformer.eval()
        Qformer.train = disabled_train
        query_tokens.requires_grad = False
        logging.info("freeze Qformer")

    return Qformer, query_tokens
```
- Q-Former design: initialized from a BERT architecture with cross-attention layers added, so it can interact with the visual embeddings.
- Query tokens: these learnable query vectors extract the key information from the visual encoding during training.
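As a rough illustration of what the query tokens do, the toy example below uses a single nn.MultiheadAttention layer as a simplified stand-in for the Q-Former's cross-attention blocks. The extra visual_proj layer exists only to match dimensions in this toy and is not part of MiniGPT-4 (the real Q-Former handles the 1408-dim visual features internally via encoder_width):

```python
import torch
import torch.nn as nn

# Assumed sizes: 32 query tokens with BERT-base hidden size 768,
# visual features with the ViT-G width of 1408
num_query_token, hidden, vit_width = 32, 768, 1408

query_tokens = nn.Parameter(torch.zeros(1, num_query_token, hidden))
query_tokens.data.normal_(mean=0.0, std=0.02)

# Simplified stand-in for one Q-Former cross-attention layer
visual_proj = nn.Linear(vit_width, hidden)    # only to match dims in this toy example
cross_attn = nn.MultiheadAttention(hidden, num_heads=12, batch_first=True)

image_embeds = torch.randn(2, 257, vit_width)  # dummy ViT output
queries = query_tokens.expand(image_embeds.shape[0], -1, -1)
kv = visual_proj(image_embeds)

# Each of the 32 queries attends over all 257 visual tokens
out, _ = cross_attn(queries, kv, kv)
print(out.shape)  # torch.Size([2, 32, 768])
```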
3.2 Image Encoding and Vision-Language Fusion
```python
def encode_img(self, image):
    device = image.device

    # Flatten any extra leading dimensions (e.g. several images per sample)
    if len(image.shape) > 4:
        image = image.reshape(-1, *image.shape[-3:])

    with self.maybe_autocast():
        # ViT features followed by LayerNorm
        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
        if self.has_qformer:
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            # The learned query tokens cross-attend to the image embeddings
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            # Project the Q-Former outputs into the LLM embedding space
            inputs_llama = self.llama_proj(query_output.last_hidden_state)
        else:
            # No Q-Former: drop the CLS token and merge every 4 neighboring
            # patch tokens before projecting them directly into the LLM space
            image_embeds = image_embeds[:, 1:, :]
            bs, pn, hs = image_embeds.shape
            image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
            inputs_llama = self.llama_proj(image_embeds)

        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
    return inputs_llama, atts_llama
```
- Vision encoding: the image passes through the vision encoder to produce embeddings, which are normalized with LayerNorm.
- Query tokens interacting with the image embeddings: the Q-Former uses cross-attention to extract, from the image embeddings, the features relevant to the query tokens.
- Linear projection: llama_proj projects the Q-Former output features into the LLM embedding dimension, so they align with the language model's inputs.
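The projection itself is just a single linear layer. A minimal sketch, assuming the BERT-base Q-Former hidden size of 768 and a 4096-dim embedding space for a 7B LLaMA/Vicuna model (5120 for the 13B variant):

```python
import torch
import torch.nn as nn

q_hidden, llm_hidden = 768, 4096      # assumed: BERT-base Q-Former -> 7B LLaMA-family model
llama_proj = nn.Linear(q_hidden, llm_hidden)

# Q-Former output: one vector per learned query token
query_output = torch.randn(2, 32, q_hidden)

inputs_llama = llama_proj(query_output)
print(inputs_llama.shape)  # torch.Size([2, 32, 4096])
```

Each image is thus turned into 32 "soft tokens" that the frozen language model consumes exactly like word embeddings.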
3.3 Language Model Input Processing and Forward Pass
```python
def preparing_embedding(self, samples):
    ### prepare input tokens
    if 'image' in samples:
        img_embeds, img_atts = self.encode_img(samples["image"])
    else:
        img_embeds = img_atts = None

    if 'conv_q' in samples:
        # handling conversation datasets
        conv_q, conv_a = samples['conv_q'], samples['conv_a']

        connect_sym = samples['connect_sym'][0]
        conv_q = [q.split(connect_sym) for q in conv_q]
        conv_a = [a.split(connect_sym) for a in conv_a]

        conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]

        # condition on the image plus the first question; the rest of the
        # conversation is tokenized into the regression targets
        cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
        regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)

    else:
        if "instruction_input" in samples:
            instruction = samples["instruction_input"]
        elif self.prompt_list:
            instruction = random.choice(self.prompt_list)
        else:
            instruction = None

        if hasattr(self, 'chat_template') and self.chat_template:
            instruction = [self.prompt_template.format(instruct) for instruct in instruction]

        if 'length' in samples:
            # the input is a sequence of images (e.g. video frames)
            bsz, pn, hs = img_embeds.shape
            img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
            cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
        else:
            cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)

        ### prepare target tokens
        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["answer"]]

        regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(self.device)

        regress_token_ids = regress_tokens.input_ids
        regress_atts = regress_tokens.attention_mask
        # padded positions are set to -100 so they are ignored by the loss
        part_targets = regress_token_ids.masked_fill(
            regress_token_ids == self.llama_tokenizer.pad_token_id, -100
        )

    regress_embeds = self.embed_tokens(regress_token_ids)

    return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
```
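The key detail in the target preparation is that padded positions are replaced with -100, the default ignore_index of PyTorch's cross-entropy loss. A standalone sketch with made-up token ids and a hypothetical pad id of 0:

```python
import torch

pad_token_id = 0  # assumed pad id, for illustration only

# Hypothetical tokenized answers of different lengths, right-padded to the longest one
regress_token_ids = torch.tensor([
    [319, 9750, 373, 263, 1591, 2],
    [450, 11203, 2, 0, 0, 0],
])

# Pad positions become -100 so the loss ignores them
part_targets = regress_token_ids.masked_fill(regress_token_ids == pad_token_id, -100)
print(part_targets)
# tensor([[  319,  9750,   373,   263,  1591,     2],
#         [  450, 11203,     2,  -100,  -100,  -100]])
```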
```python
def forward(self, samples, reduction='mean'):
    # prepare the embedding to condition on and the embedding to regress
    cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
        self.preparing_embedding(samples)

    # concat the conditioning embeddings and the regression embeddings
    inputs_embeds, attention_mask, input_lens = \
        self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)

    # get bos token embedding
    bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
    bos_embeds = self.embed_tokens(bos)
    bos_atts = cond_atts[:, :1]

    # add bos token at the beginning
    inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
    attention_mask = torch.cat([bos_atts, attention_mask], dim=1)

    # assemble the final targets: every position except the answer tokens is -100
    targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
                         dtype=torch.long).to(self.device).fill_(-100)

    for i, target in enumerate(part_targets):
        targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos

    with self.maybe_autocast():
        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
            reduction=reduction
        )
    loss = outputs.loss

    return {"loss": loss}
```
4. Some Questions and Answers (continuously updated)
- Why add a linear projection layer after the Q-Former?
  - Dimension alignment: it maps the Q-Former output to the LLM's input dimension.
  - Semantic alignment: it helps the visual embeddings fuse into the language model's semantic space.
  - Computational efficiency: a simple linear transformation is enough for effective feature mapping.
  - Modular design: it keeps the model flexible, so different vision and language models can be combined.
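A quick way to see how lightweight this design is: counting the parameters of the projection layer (assuming a 768-dim Q-Former output and the 5120-dim embedding of Vicuna-13B used in the paper) gives on the order of 4M trainable parameters, against the billions that stay frozen.

```python
import torch.nn as nn

# Assumed dimensions: BERT-base Q-Former output (768) -> Vicuna-13B embedding (5120)
llama_proj = nn.Linear(768, 5120)

trainable = sum(p.numel() for p in llama_proj.parameters() if p.requires_grad)
print(f"trainable parameters in the projection layer: {trainable:,}")  # ~3.9M
```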
References
- Zhu, Deyao, et al. "MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models." arXiv preprint arXiv:2304.10592 (2023).