BLIP: Bootstrapping Language-Image Pre-training,自举语言图像预训练模型
本次分享对BLIP做简要分析,希望对大家有所帮助。
论文地址:https://arxiv.org/pdf/2201.12086.pdf
源代码: https://github.com/salesforce/BLIP
1 贡献点
1.1 MED
MED(Multimodal mixture of Encoder-Decoder)一种编码器-解码器混合架构 ,MED 的特点是很灵活,它既可以作为单模态的编码器,又可以作为基于图像的文本编码器。
1.2 CapFilt
CapFilt(Captioning and Filtering)一种高效率利用噪声网络数据的方法。
2 介绍
2.1 MED
以往的预训练模型仅在基于理解的任务或者基于生成的任务方面表现出色,很少有可以兼顾的模型。比如,基于编码器的模型,像 CLIP,ALBEF 不能直接转移到文本生成任务 (比如图像字幕),而基于编码器-解码器的模型,像 SimVLM 不能直接用于图像文本检索任务。
BLIP 的模型架构 MED由三个视觉语言目标联合训练:
图像文本的对比学习: 下图前两列。最左边的是图像编码器,将输入图像分割成一个个的 Patch 并将它们编码为一系列 Image Embedding。第2列的是文本编码器,其中 [CLS] token 附加到文本输入的开头以总结句子,提取文本特征Text Embedding。使用ITC(Image-Text Contrastive Loss)损失,目标是对齐视觉和文本的特征空间。方法是使得正样本图文对的相似性更大,负样本图文对的相似性更低。
图像文本匹配:下图第三列。视觉文本编码器,使用 Cross-Attention,且注意力部分是双向的 Self-Attention。添加一个额外的 [Encode] token,作为图像文本的联合表征。使用ITM(Image-Text Matching Loss)损失,目标是学习图像文本的联合特征,捕获视觉和语言之间的特征对齐,以区分图像-文本对是正样本还是负样本。
图像条件语言建模:下图最后一列。视觉文本解码器,使用 Cross-Attention,且注意力部分是 Casual-Attention,目标是预测下一个 token。添加一个额外的 [Decode] token 和结束 token,作为生成结果的起点和终点。使用LM((Language Modeling Loss)作用于第1列视觉编码器和第4列视觉文本编码器,目标是根据给定的图像生成图像的文本描述。
2.2 CapFilt
BLIP 这里提出了一种高效率利用噪声网络数据的方法:CapFilt(Captioning and Filtering)。如下图所示,它包含两个模块:
字幕器 Captioner:给一张网络图文对中的图片,文本部分为下图粉色框部分。通过Cap生成字幕下图绿色框部分。
过滤器 Filter:过滤掉噪声图文对。它是一个视觉文本编码器,看文本是否与图像匹配。 下图Filt预测两个文本和图像匹配情况,判断文本是否有噪声。
最后,将过滤后的图像-文本对与人工注释对相结合,形成一个新的数据集,作者用它来预训练一个新的模型。
详细流程如下:
其中 、
: 带有噪音的网络图文对中的图像、文本,
、
: 人工标注图文对中的图像、文本,
红色 、绿色
: 经Filtering过滤前、后的文本,
首先,基于 、
、
、
训练一个预训练模型,如下图左侧。使用
、
训练Filter和Captioner;
然后,基于Captioner看图说话模型,输入 ,得到
对应的红色
;
接着,基于Filter过滤器,输入带有噪音的( ,红色
)和(
,红色
),过滤后得到(
,绿色
)(
,绿色
)];
最后,将[( ,绿色
)(
,绿色
)(
,
)]作为新数据集,重新训练模型。
3 实验
联合使用 Captioner 和 Filter 可以观察到性能改进,而且它们的效果相互互补,证明了 CapFilt 方法能够从嘈杂的原始数据中提炼出有用的数据。
与现有方法相比,BLIP 实现了显着的性能提升。
4 代码Image-Text Captioning:
import os
from PIL import Image
from transformers import pipeline
blip_pipe = pipeline("image-to-text", model="blip/blip-image-captioning-large/", device="cuda:0", max_new_tokens=70)
init_image = Image.open('test.jpg').convert("RGB")
prompt = blip_pipe(init_image)[0]["generated_text"]
print('prompt')