ALBEF算法解读

ALBEF论文介绍了一种在视觉和语言表示学习中应用动量蒸馏的技术,通过对比学习(ITC)、匹配学习(ITM)和语言建模(MLM)改进多模态模型的性能。论文强调了模态融合的重要性,特别是视觉模型的规模,并提出了动量模型来处理噪声数据和困难的负样本。
摘要由CSDN通过智能技术生成

ALBEF论文全名Align before Fuse: Vision and Language Representation Learning with Momentum Distillation,来自于Align before Fuse,作者团队为Salesforce Research。
论文地址:https://arxiv.org/pdf/2107.07651.pdf
论文代码:GitCode - 开发者的代码家园

1、灵感来源

a)图 (a) 是VSE架构,它们的文本端就是直接抽一个文本特征,但是它们的视觉端非常大,需要的计算量非常多,因为它通常是一个目标检测器。当得到了文本特征和视觉特征之后,它最后只能做一个很简单的模态之间的交互,从而去做多模态的任务。

b)图(b)是CLIP的结构,视觉端和文本端都用同等复杂度的encoder进行特征提取,再做一个简单的模态交互,结构优点是对检索任务极其有效,因为它可以提前把特征都抽好,接下来直接算Similarity矩阵乘法就可以,极其适合大规模的图像文本的检索,非常具有商业价值。缺点是只计算Cosine Similarity无法做多模态之间深度融合,难一些的任务性能差。

c)图(c)是Oscar或者ViLBERT、Uniter采用的架构,因为对于多模态的任务,最后的模态之间的交互非常重要,只有有了模态之间更深层的交互,VQA、VR、VE这些任务效果才会非常好,所以他们就把最初的简单的点乘的模态之间的交互,变成了一个Transformer的Encoder,或者变成别的更复杂的模型结构去做模态之间的交互,所以这些方法的性能都非常的好,但是随之而来的缺点也很明显:所有的这一系列的工作都用了预训练的目标检测器,再加上这么一个更大的模态融合的部分,模型不论是训练还是部署都非常的困难。

d)图 (d) 是ViLT的架构。当Vision Transformer出来之后,ViLT这篇论文就应运而生了,因为在Vision Transformer里,基于Patch的视觉特征与基于Bounding Box的视觉特征没有太大的区别,它也能做图片分类或者目标检测的任务,因此就可以将这么大的预训练好的目标检测器换成一层Patch Embedding就能去抽取视觉的特征,所以大大的降低了运算复杂度,尤其是在做推理的时候。但是如果文本特征只是简单Tokenize,视觉特征也只是简单的Patch Embedding是远远不够的,所以对于多模态任务,后面的模态融合非常关键,所以ViLT就直接借鉴 c类里的模态融合的方法,用一个很大的Transformer Encoder去做模态融合,从而达到了还不错的效果。因为移除了预训练的目标检测器,换成了可学习的Patch Embedding Layer。

ViLT的优点:模型极其简单。它虽然是一个多模态学习的框架,但跟NLP的框架没什么区别,就是先Tokenized,然后送到一个Transformer去学习,所以非常的简单易学。
ViLT的缺点:1)它的性能不够高,ViLT在很多任务上是比不过 © 类里的这些方法的,原因是对于现有的多模态任务,需要更多的视觉能力(可能是由于数据集的bias),因此视觉模型需要比文本模型要大,最后的效果才能好,但是在ViLT里,文本端用的Tokenizer很好,但是Visual Embedding是Random Initialized,所以它的效果就很差2)ViLT虽然推理时间很快,但它的训练时间非常慢,在非常标准的一个4 million的数据集set上,ViLT需要64张32G的GPU训练三天,它训练的复杂度和训练的时间丝毫不亚于 © 类的方法,所以它只是结构上简化了多模态学习,但训练难度并没有降低。

从上面可以发现,a和b的模态融合是比较弱的,比如b的代表就是clip,他是使用余弦相似度计算文本和图像的相似性,这对于VQA,VE之类的任务效果就比较差。c的代表是ViLBERT,除了多模态融合使用较大模型外,在图像端也使用了较大的模型。d的代表是ViLT,只在多模态融合时使用较大模型。

那么我们发现,b类模型的模态融合方面较弱,适合做检索但不适合做推理问答;cd类模型在模态融合方面较强,适合做推理问答但不适合做检索。

因此,在模型的选择上:

        1、因为有图像的输入和文本的输入,模型应有两个分支分别抽取图像文本特征。
        2、在多模态学习里,视觉特征重要性远远要大于这个文本特征,所以应该使用更大更强的视觉模型。
        3、多模态学习模态之间的融合也非常关键,因此需要模态融合的模型尽可能大,所以好的多模态学习网络结构应该像 ©,也就是文本编码器比图像编码器小,多模态融合的部分尽可能大。

2、ALBEF预训练

ALBEF有三个损失函数:图像文本对比损失(ITC)、图像文本匹配损失(ITM)和掩码语言建模损失(MLM)。

图像和文本对在经过图像编码器和文本编码器后会算一个ITC,促使学习特征对齐的能力。对齐后的文本和图像特征再经过多模态编码器会算一个ITM,促使学习特征交互的能力。仿照BERT,文本mask后再单独过一遍文本编码器和多模态编码器来算MLM,促使学习语言和特征交互的能力。

另外,还有两个训练技巧:

        1、使用ITC相似度矩阵,挖掘困难的负样本,来提升ITM的难度,进而督促模型学习。

        2、由于数据中有噪声(有的负文本对更能表述图像),仿照MoCo,使用动量模型作为teacher模型,在学习的过程中student模型以teacher的软标签作为label进行学习,能够抵抗数据质量不高和一图多义的问题。

图像端:给定任何一张图片,按照Vision Transformer的做法,把它打成patch,然后通过patch embedding layer送给一个Vision Transformer。这里是一个非常标准的12层的Vision Transformer的base模型,如果图片是224x224,那这里的sequence length就是196,然后加上额外的一个CLS token就是197,它的特征维度是768,所以这里绿黄色的特征就是197乘以768。但论文在预训练阶段用的图片是256x256,所以这里绿色的sequence length就会相应的再长一些。它的预训练参数用的DEiT,也就是Data Efficient Vision Transformer在ImageNet 1K数据集上训练出来的初始化参数。
文本端:文本端为保持计算量与clip类似,并且增强模态融合的部分,用前六层做文本编码,剩下的六层transformer encoder作为multi-model fusion的过程。文本模型用BERT模型做初始化,它中间的特征维度也是768,它也有一个CLS token代表了整个句子的文本信息。
momentum model:ALBEF为了做momentum distillation,而且为了给ITC loss提供更多negative,增加了momentum model,这个模型参数由左边训练的模型参数通过moving average得到的(和MoCo一模一样),通过把moving average的参数设的非常高(论文里是0.995)来保证momentum model不会那么快更新,产生的特征更加稳定,不仅可以做更稳定的negative sample,而且还可以去做momentum distillation。

ITC

ITC目的是在融合前学习更好的单模态表示,并且对齐图像和文本特征。它学习一个相似度函数,使得相关联的图像-文本对具有更高的相似度得分。受MoCo的启发,维护了两个队列来存储来自动量单模态编码器的最新M个图像-文本表示。

ITC loss:对比学习就是首先定义一个正样本对,然后定义很多负样本对,对比使正样本对之间的距离越来越近,正负样本对之间的距离越来越远。在ALBEF里首先将图像I通过vision transformer之后得到图像的全局特征,图中黄色CLS token作为全局特征,也就是一个768×1的向量,文本这边先做tokenize,将文本text变成tokenization的序列,再输入BERT的前六层,得到了一系列的特征,文本端的CLS token作为文本的全局特征,也是一个768×1的向量。接下来与MoCo相同,图像的特征向量先做一下downsample和normalization,将768×1变成了256×1的向量,文本特征向量也是768变成256×1。正样本这两个特征距离尽可能的近,它的负样本全都存在一个q里,含有65536个负样本(因为它由momentum model产生的,没有gradient,所以它并不占很多内存),正负样本之间的对比学习,使得这两个特征距离尽可能的远。这个过程就是align before fuse的align,也就是说在图像特征和这个文本特征输入Multi-model Fusion的encoder之前,就已经通过对比学习的ITC loss让这个图像特征和文本特征尽可能的拉近,在同一个embedding space里,具体使用了cross entropy loss。

ITM

ITM预测一个图文对是正对(匹配)还是负对(不匹配)。我们使用多模态编码器的[CLS] token的输出嵌入作为图像-文本对的联合表示,并附加一个全连接(FC)层,然后是softmax来预测两类概率。

本文提出一种策略,为ITM任务采样较难的负样本,而且计算开销为零。如果图像-文本具有相似的语义,但在细节上存在差异,这就是一个比较难的图像-文本对。我们使用公式1中的对比相似性来找到批量内的困难的负对。对于小批量中的每个图像,我们按照对比相似性分布从同一批次中采样一个负文本,其中与图像更相似的文本有更高的机会被采样。同样,我们还为每个文本采样一个难的负图像。

ITM loss:Image Text Matching,属于一个二分类任务,就是给定一个图片,给定一个文本,图像文本通过ALBEF的模型之后输出一个特征,在这个特征之后加一个分类头,也就是一个FC层,然后去判断I和T是不是一个对,这个loss虽然很合理,但是实际操作的时候发现这个loss太简单,所以这个分类任务,很快它的准确度就提升得很高无法进一步优化。因此ALBEF通过某种方式选择最难的负样本(最接近于正样本的那个负样本),具体来说ALBEF的batch size是512,所以ITM loss正样本对就是512个,对于mini batch里的每一张图像,把这张图片和同一个batch里所有的文本都算一遍cos similarity,然后它在这里选择除了它自己之外相似度最高的文本当做negative,这样ITM loss就非常有难度。

MLM

MLM利用图像和对应的文本来预测被遮盖的token。我们以15%的概率随机遮盖输入token,并用特殊标记[mask]替换它们。

MLM Loss:Mask Language Modeling,它把原来完整的句子text T变成一个T’,也就是有些单词被mask掉了,然后它把缺失的句子和图片一起通过ALBEF的模型,最后把之前的完整的句子给预测出来,它这里也借助了图像的信息去更好的恢复被mask掉的单词。

momentum distillation动量蒸馏

用于预训练的图像-文本对大部分是从web上收集的,它们往往有噪声。正相关对通常是弱相关的:文本可能包含与图像无关的单词,或者图像可能包含文本中没有描述的实体。对于ITC学习,图像的负面文本也可能与图像内容相匹配。对于MLM,可能存在与标注不同的其他单词,这些单词同样好地描述了图像(或更好)。然而,ITC和MLM的one-hot标签会惩罚所有负面预测,无论其正确性如何

为解决这个问题,本文建议从动量模型生成的伪目标中学习。动量模型是一个不断进化的老师,由单模态和多模态编码器的指数移动平均版本组成。在训练过程中,我们训练基础模型,使其预测与动量模型的预测相匹配。

具体的,作者认为one hot label(就是图片和文本就是一对,其他跟它都不是一对)对于ITC和MLM这两个loss来说不好,因为有的负样本也包含了很多的信息。所以作者采取了自训练方式,先构建一个momentum model然后用这个动量模型去生成pseudo targets伪目标(其实就是一个softmax score),这样它就不再是一个one hot label。
动量模型在已有模型之上做exponential moving average EMA,目的是在原始模型训练的时候,不仅希望模型预测与ground truth的one hot label去尽可能的接近,还希望模型预测与动量模型出来的pseudo targets尽可能的匹配,这样就能达到一个比较好的折中点。因为当one hot label正确时,可以学习到很多信息,但当one hot label是错误的,或者是noisy的时候,作者希望稳定的momentum model能够提供一些改进。
以ITC loss为例,它是基于这个one hot label的,所以这里再算一个pseudo target loss去弥补它的一些缺陷和不足。这个loss跟前面equation1里的ITC loss的不同,这里将ground truth换成这个q,就是pseudo targets,q不再是one hot label,而是softmax score,所以这里计算KL divergence而不是cross entropy。

最终ALBEF的训练loss有五个:两个ITC、两个MLM、一个ITM,其中:
1)ITC有两个loss:一个是原始的ITC,一个是基于pseudo target的ITC,所以分别加权(1-α)和α的loss weight,最终得到momentum版本的ITC loss。
2)MLM loss有两个loss:一个是原始的MLM,一个是基于pseudo target的MLM,用新生成的pseudo target去代替了原来的ground truth,分别加权(1-α)和α的loss weight,最终得到momentum版本的MLM loss。
3)ITM有一个loss:ITM没有动量版本,因为本身它就是基于ground truth,它就是一个二分类任务,而且在ITM里又做了hard negative,这跟momentum model有冲突。

  • 18
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值