VILT算法解读

VILT是一种典型的单塔结构,不同于双塔结构由两个独立的Image Encoder以及Text Encoder组成(比如clip),单塔结构的模型一般只有一个共用的编码器,称为Multi-Modal Encoder。 

1、VILT算法原理
    VILT被认为是最简单的单塔结构,它用统一的方式处理视觉和文本两种模态的输入,利用Transformer结构来作为共用的编码器学习模态间的信息交互(多模态融合)。其具体结构如下所示:


(1)数据输入处理

    既然是图文多模态,输入数据必然包含成对的图像和文本(图文pair)。

首先是文本输入,原始的文本输入可以看成是一个单词序列(比如图中的a,stone,statue等等),然后通过Word Embedding算法处理为词嵌入向量,所谓词嵌入向量就是一个高维的向量表征,同一个单词具有相同的嵌入表征,同时同义词的表征距离越近(具体词嵌入的过程可以参考NLP的一些资料,简单来说就是需要一个词汇表(vocabulary dict),然后将单词分词后转化为词汇表中的索引,再将数字索引变为高维向量)。注意到除了词嵌入向量外,每个单词最后的嵌入向量还添加了位置嵌入向量和模态类别嵌入向量(三种向量进行concat)。

然后是图像输入,与文本不同,原始的图像需要分块处理,变成一个图像块的序列,然后每个图像块会通过一个Linear Projection层投影成一个高维视觉特征图,然后将二维特征图展开变成一位特征向量,这样每个图像块就都变成了一个视觉特征向量,同时特征维度和词嵌入向量的特征维度一致。同时每个图像块除了视觉特征向量外,也会添加位置嵌入向量和模态类别嵌入向量(三种向量进行concat)。

    此外,两种模态分别都嵌入了一个额外的可学习[class] embedding,方便和下游任务对接。通过对两种输入的处理,最后不管是文本单词,还是图像块,都被embedding为了同样维度的token特征。

(2)模型结构

    由于视觉和文本输入被转化为了相同的token特征,所以后续的处理方式就可以统一了。因此VILT仅通过一个Transformer的结构就可以进行建模,它使用了一个预训练的ViT来初始化这个Transformer。

(3)预训练任务(损失函数)

    VILT在预训练阶段主要使用了两种较典型的优化目标,分别为Image Text Matching(ITM)和Masked Language Modeling(MLM)。

Image Text Matching(ITM)

    图像-文本匹配(ITM)可以预测一对图像和文本是正例(匹配)还是负例(不匹配)。使用多模态编码器输出嵌入的[CLS] token作为图像-文本对的联合表示,并附加一个全连接(FC)层,最后使用Softmax来预测一个二分类概率


    其中,y是一个表示ground truth标签的二维one-hot向量。

    ITM在使用时候的关键在于如何构造正负例样本,ITM 本质上是二分类,正样本就是匹配的图文对,负样本需要在训练的过程中进行采样,对每个图像一般在数据集中随机采样一个不匹配的文本作为负样本;对每个文本在数据集中随机采样1个不匹配的图像作为负样本。

Masked Language Modeling(MLM)

    MLM是训练NLP中Bert模型所使用的预训练任务,通常屏蔽(或随机替换)给定句子中特定百分比(15%)的单词,模型期望基于该句子中的其他单词预测这些被屏蔽的单词。这样的训练方案使这个模型在本质上是双向的,因为掩蔽词的表示是根据出现的词来学习的,不管是左还是右。你也可以把它想象成一个填空式的问题陈述。


    在单塔多模态中的MLM与Bert中的稍有不同,主要区别还是在输入信息,VILT中的MLM在预测被MASK的单词时,除了依赖上下文还可以依赖视觉模态的信息。

    另外,VILT在使用MLM时还结合了whole word masking技巧。所谓whole word masking是将连续的子词tokens进行mask的技巧,因为大部分的分词算法可能会将一个单词分为多个文本token,因此普通的MLM不能做到基于单词粒度的MASK,只能基于文本token这个粒度进行MASK。

2、VILT适合的下游任务
    VILT这种单塔的多模态结构由于在预训练的时候进行了充分的模态交互(模态融合),因此适合下游诸如VQA(视觉问答),NLVR(自然语言视觉推理 )等等。

(1)VILT进行VQA

    VQA任务就是根据一张图片来提问,希望模型能正确回答对应的问题。比如下面有一张图片,图片里有两只猫,问题Question是“How many cats are there?”,希望模型给出的答案Answer是“2”


(2)VILT进行NLVR

    VILT还可以用来进行NLVR任务,NLVR判断两张图片和一条文本的语义是否匹配,比如准备两张图片,对应的文本内容是text,如果图片和文本内容匹配则模型输出True,否则输出False。


3、小结
     VILT作为多模态单塔(单流)结构中的代表工作,虽然其在检索上的效率偏低,但是其在下游许多需要强交互的多模态理解任务上表现出了比双塔结构更强劲的性能!

### ViLT 模型架构详解 ViLT(Vision-and-Language Transformer)是一种基于Transformer的多模态模型,旨在统一处理视觉和语言任务。以下是其主要组成部分及其功能: #### 1. **轻量级视觉输入嵌入** 为了减少计算复杂度并提高效率,ViLT采用了一种简单的线性投影方法来嵌入图像补丁(patches)。这种方法避免了传统卷积神经网络(CNNs)作为视觉特征提取器的需求,从而显著降低了计算开销[^4]。 #### 2. **拼接后的向量表示** 在ViLT中,文本和图像被分别编码成固定长度的向量序列,并通过一种特殊的方式进行拼接。具体来说,经过预处理的文本和图像数据会被转换为具有相同维度的向量形式,最终形成形状为 `[batch_size, text_max_length + 145, 768]` 的张量[^5]。这种设计使得两种不同类型的输入能够无缝融合到同一个Transformer框架下。 #### 3. **Encoder 层的设计** ViLT 的 Encoder 部分遵循 Vision Transformers (ViT) 的经典布局,主要包括以下几个子模块: - **自注意力机制**: 实现于 `ViltSelfAttention` 和 `ViltSelfOutput` 中,负责捕捉全局上下文中各个 token 之间的关系。 - **交叉注意机制**: 第一次黄色加绿色区域对应于此阶段 (`ViltAttention`) ,允许视觉与文字之间相互作用。 - **前馈网络(FFN)**: 黄色加上蓝色部分构成了标准 MLP 结构(`ViltMLP`) ,进一步增强了表达能力。 值得注意的是,由于引入了专门定制化的 Layer ——即所谓的 “VlitLayer”, 用户可以根据实际需求灵活调整堆叠层数以适应不同的应用场景. #### 4. **训练技巧的影响分析** 从实验结果来看,某些常见的 NLP 或 CV 技巧对提升 ViLT 性能的效果并不一致: - Whole Word Masking(WWM): 对整体表现有所改善但幅度较小. - Masked Patch Prediction(MPP): 尝试利用该目标函数改进图片理解效果不佳因此未继续沿用 . - Random Augmentation(RandAugment): 显示出了较为明显的正面贡献.[^3] 综上所述,VILT 不仅继承了 transformer 在单一领域内的强大优势还创造性地解决了跨媒体间协作难题. ```python class ViltModel(nn.Module): def __init__(self, config): super().__init__() self.text_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.patch_embeddings = nn.Linear(config.image_patch_dim, config.hidden_size) encoder_layer = nn.TransformerEncoderLayer( d_model=config.hidden_size, nhead=config.num_attention_heads, dim_feedforward=config.intermediate_size, dropout=config.hidden_dropout_prob ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers) def forward(self, input_ids, pixel_values): # Text embeddings text_embeds = self.text_embeddings(input_ids) # Image patch embeddings batch_size, _, height, width = pixel_values.shape patches = rearrange(pixel_values, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16) image_embeds = self.patch_embeddings(patches) # Concatenate and encode combined_embeds = torch.cat([text_embeds, image_embeds], dim=1) output = self.encoder(combined_embeds.permute(1, 0, 2)) return output.permute(1, 0, 2) ``` 上述代码片段展示了如何构建一个基本版本的 ViLT 模型,其中包含了文本和图像的嵌入过程以及后续的 Transformer 编码操作。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值