VLP、多模态图文任务(2)实例解读

    

        上一篇我们介绍了图像文本领域中的通用模型架构和流行的预训练任务。为了提供更具体的例子,我们选择了四个代表性模型作为案例研究,包括

  1. UNITER,一种基于OD的图像文本模型;
  2. ViLT,一种基于视觉Transformer的最小端到端图像文本模型;
  3.  ALBEF,一种使用对比和生成目标进行预训练的端到端图像文本模型;
  4.  SimVLM,第一个大规模预训练的编码器-解码器图像文本模型,其预训练目标是简单的PrefixLM。

以下简要介绍它们的架构和预训练任务。

  • UNITER。UNITER的架构如图(a)所示。图像通过离线预训练的目标检测模型进行编码,以提取区域特征。然后,将这些图像特征与输入文本的单词嵌入连接在一起,并添加位置嵌入,随后通过几个Transformer层进行多模态融合。该模型通过常用的任务进行预训练,包括遮蔽语言建模、图像-文本匹配和遮蔽区域建模。作者还通过使用最优传输提供了单词-区域对齐损失。多模态Transformer是通过预训练的BERT-base或BERT-large模型初始化的。
  • ViLT。(b)说明了ViLT的模型架构,这是一个最简单的图像文本模型。图像被分成图块,并通过图块嵌入进行编码,文本通过单词标记嵌入进行编码。这些特征被连接并发送到Transformer中,该Transformer通过在ImageNet22k上进行监督预训练的普通视觉Transformer进行初始化。预训练是通过遮蔽语言建模、图像-文本匹配、匹配图块建模和词块对齐进行的。
  • ALBEF。如图(c)所示,ALBEF采用了一般的VLP架构,这也在METER中广泛研究。具体而言,使用视觉Transformer对图像进行编码,使用BERT模型的前6层对文本进行编码,然后通过BERT模型的最后6层进行多模态融合。关键创新在于在预训练过程中使用对比目标,这在CLIP中引入,但尚未用于基于融合编码器的图像文本模型。通过将对比损失纳入预训练,可以通过两个特征向量的简单点积实现快速的图像-文本检索,而需要深度多模态融合的VQA和图像字幕任务也可以通过顶部的融合层进行处理。
  • SimVLM 如图(d)所示。CLIP和ALIGN是第一款大规模预训练的双编码器模型,仅适用于(零样本)图像分类和检索,而SimVLM是第一款可以用于需要深度多模态融合的任务的大规模预训练编码器-解码器模型。此外,预训练目标被简化为单一的PrefixLM损失。该模型在训练大型图像文本模型方面显示出巨大的潜力。

一: VLP基本操作代码解读

在处理图像和文本上许多模型是相同的,对于CV的处理初始图像的方法有OD、ViT

1.1 CV处理方法

1.1.1OD

1.1.2 ViT

        Vision Transformer是Transformer的视觉版本。将Transformer应用于图像图块(patch)序列上,打破了这种NLP与CV的隔离。简单来理解,Vision Transformer就是将输入进来的图片,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列,将组合后的结果传入Transformer特有的Multi-head Self-attention进行特征提取。最后利用Cls Token进行分类。

ViT分为两部分,一部分是特征提取部分,另一部分是分类部分。

特征提取部分

        特征提取部分在图片中的对应区域是Patch+Position Embedding和Transformer Encoder。P前部分作用主要是对输入进来的图片分块,每隔一定的区域大小划分图片块,再对图片块组合成序列。在获得序列信息后,传入Transformer Encoder进行特征提取,通过自注意力机制,关注每个图片块的重要程度。

        如何分块?使用卷积。由于卷积使用的是滑动窗口的思想,我们只需要设定特定的步长,就可以输入进来的图片进行分块处理了

         在VIT中,我们常设置这个卷积的卷积核大小和步长都是16x16,此时卷积就会每隔16个像素点进行一次特征提取,两个图片区域的特征提取过程就不会有重叠。当我们输入的图片是224, 224, 3的时候,我们可以获得一个14, 14, 768的特征层。(224=26*14,16*16*3=768,有种变相保存的感觉,两维表示特征提取,一维表示保存原图信息)

        下一步就是将这个特征层组合成序列,就是将高宽维度进行平铺,14, 14, 768在高宽维度平铺后,获得一个196(14*14), 768的特征层。平铺完成后,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,图中的这个0*就是Cls Token,我们此时获得一个197, 768的特征层。这里和文本处理最不同的就是将图块切分成序列的形式,代码如下:

class PatchEmbed(nn.Module):
    def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
        super().__init__()
        self.num_patches    = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
        self.flatten        = flatten

        self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(num_features) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

class VisionTransformer(nn.Module):
    def __init__(
            self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
            depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
        ):
        super().__init__()
        #-----------------------------------------------#
        #   224, 224, 3 -> 196, 768
        #-----------------------------------------------#
        self.patch_embed    = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
        num_patches         = (224 // patch_size) * (224 // patch_size)
        self.num_features   = num_features
        self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
        self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]

        #--------------------------------------------------------------------------------------------------------------------#
        #   classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
        #
        #   在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
        #   此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
        #   在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
        #--------------------------------------------------------------------------------------------------------------------#
        #   196, 768 -> 197, 768
        self.cls_token      = nn.Parameter(torch.zeros(1, 1, num_features))
        #--------------------------------------------------------------------------------------------------------------------#
        #   为网络提取到的特征添加上位置信息。
        #   以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
        #   此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
        #--------------------------------------------------------------------------------------------------------------------#
        #   197, 768 -> 197, 768
        self.pos_embed      = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))

    def forward_features(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1) 
        x = torch.cat((cls_token, x), dim=1)
        
        cls_token_pe = self.pos_embed[:, 0:1, :]
        img_token_pe = self.pos_embed[:, 1: , :]

        img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
        img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
        img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
        pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)

        x = self.pos_drop(x + pos_embed)

其他的部分和Transformer Encoder基本一致。

分类部分

        在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features    = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs      = (drop, drop)

        self.fc1    = nn.Linear(in_features, hidden_features)
        self.act    = act_layer()
        self.drop1  = nn.Dropout(drop_probs[0])
        self.fc2    = nn.Linear(hidden_features, out_features)
        self.drop2  = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1      = norm_layer(dim)
        self.attn       = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2      = norm_layer(dim)
        self.mlp        = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path  = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

2.1 NLP处理方法

        目前对CV和NLP的处理都是基于Transformer 进行变种操作,这里不进行展开,具体的细节区分会单独出一篇博客介绍语言模型(LM)预训练。Transfomer代码介绍见:pytorch7——模型层之Transformer-CSDN博客

二: 四个实例   

1. UNITER代码解读

2. ViLT代码解读

3. ALBEF代码解读

3.1 数据

image_embeds = self.visual_encoder(image) 
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)  

text_output = self.text_encoder.bert(text.input_ids, attention_mask = text.attention_mask,                      
                                        return_dict = True, mode = 'text')            
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1) 

3.2 模型

4.3 损失函数

ITC:

# get momentum features
with torch.no_grad():
    self._momentum_update()
    image_embeds_m = self.visual_encoder_m(image) 
    image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)  
    image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                                         
    text_output_m = self.text_encoder_m.bert(text.input_ids, attention_mask = text.attention_mask,                      
                                        return_dict = True, mode = 'text')    
    text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 
    text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)

    sim_i2t_m = image_feat_m @ text_feat_all / self.temp 
    sim_t2i_m = text_feat_m @ image_feat_all / self.temp     

    sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
    sim_targets.fill_diagonal_(1)          

    sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
    sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets        

sim_i2t = image_feat @ text_feat_all / self.temp 
sim_t2i = text_feat @ image_feat_all / self.temp 
                     
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 

loss_ita = (loss_i2t+loss_t2i)/2

self._dequeue_and_enqueue(image_feat_m, text_feat_m)
  

ITM:

# forward the positve image-text pair
output_pos = self.text_encoder.bert(encoder_embeds = text_embeds, 
                                attention_mask = text.attention_mask,
                                encoder_hidden_states = image_embeds,
                                encoder_attention_mask = image_atts,      
                                return_dict = True,
                                mode = 'fusion',
                               )            
with torch.no_grad():
    bs = image.size(0)          
    weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)
    weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)

    weights_i2t.fill_diagonal_(0)
    weights_t2i.fill_diagonal_(0)

# select a negative image for each text
image_embeds_neg = []    
for b in range(bs):
    neg_idx = torch.multinomial(weights_t2i[b], 1).item()
    image_embeds_neg.append(image_embeds[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)   

# select a negative text for each image
text_embeds_neg = []
text_atts_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_i2t[b], 1).item()
    text_embeds_neg.append(text_embeds[neg_idx])
    text_atts_neg.append(text.attention_mask[neg_idx])
text_embeds_neg = torch.stack(text_embeds_neg,dim=0)   
text_atts_neg = torch.stack(text_atts_neg,dim=0)      

text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0)     
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)     

image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
image_atts_all = torch.cat([image_atts,image_atts],dim=0)

output_neg = self.text_encoder.bert(encoder_embeds = text_embeds_all, 
                                attention_mask = text_atts_all,
                                encoder_hidden_states = image_embeds_all,
                                encoder_attention_mask = image_atts_all,      
                                return_dict = True,
                                mode = 'fusion',
                               )                         

vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
vl_output = self.itm_head(vl_embeddings)            

itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
                       dim=0).to(image.device)
loss_itm = F.cross_entropy(vl_output, itm_labels) 

MLM

##================= MLM ========================##                
input_ids = text.input_ids.clone()
labels = input_ids.clone()

probability_matrix = torch.full(labels.shape, self.mlm_probability)                    
input_ids, labels = self.mask(input_ids, self.text_encoder.config.vocab_size, image.device, targets=labels,
                              probability_matrix = probability_matrix) 

with torch.no_grad():
    logits_m = self.text_encoder_m(input_ids, 
                                   attention_mask = text.attention_mask,
                                   encoder_hidden_states = image_embeds_m,
                                   encoder_attention_mask = image_atts,      
                                   return_dict = True,
                                   return_logits = True,   
                                  )    
    
mlm_output = self.text_encoder(input_ids, 
                               attention_mask = text.attention_mask,
                               encoder_hidden_states = image_embeds,
                               encoder_attention_mask = image_atts,      
                               return_dict = True,
                               labels = labels,   
                               soft_labels = F.softmax(logits_m,dim=-1),
                               alpha = alpha
                              )                           
loss_mlm = mlm_output.loss  

3.4 优化器

4. SimVLM代码解读

只有这个论文里没有,找的代码链接:https://github.com/YulongBonjour/SimVLM

未完待续。。。。

参考:

神经网络学习小记录67——Pytorch版 Vision Transformer(VIT)模型的复现详解_pytorch官方提供的vit模型_Bubbliiiing的博客-CSDN博客

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值