Vision Transformer (VIT)

一 VIT原理介绍:

AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE(论文名称)

        transformer是NLP的首选模型,同时transformer可以应用于图像处理。vit的出现挑战了卷积神经网络在计算机视觉的地位,在足够多的数据上进行预训练,我们不使用卷积神经网络而是将自然语言处理的模型transformer直接用于计算机视觉也能很好的解决问题。同时这篇论文打破了CV与NLP的壁垒。论文中指出当对大量数据进行预训练并将其传输到多个中小型图像识别基准(ImageNet、CIFAR-100、VTAB等)时,与最先进的卷积网络相比,视觉转换器(ViT)可以获得优异的效果,同时训练所需的计算资源大大减少。同时transformer不仅能够用于图像分类,也可以应用于目标检测与图像分割。

        首要考虑如何将图像转化成transformer可以接受的输入,如果我们采用每个像素来表示输入则会导致序列长度过长,复杂度提高。其本质就是随着像素点的增加,复杂度会成平方级增长。为解决这个问题,论文中方法就是将图像化成一个一个patch,意思就是原来是一个像素点代表一个token,现在是一大块的token一个patch作为一个token。

      上图就是VIT的模型架构图,第一步就是将图像切分为patch;第二步做一个Linear Projection;由于一个patch是一个正方形不能直接做为输入,因此将一个patch转化成固定维度的embeding做输入,将patch拉平成一维向量,并映射到transformer规定的emdebing size的纬度;第三步首先生成cls的tokenembeding,然后生成序列的位置编码,最后tokenembeding加上位置编码;第四步输入到transformer模型中;第五步输出cls多分类任务。

       左面是patch embeding 右图是position embeding。

        VIT中的transformer与原始的transformer结构还是有一定区别的,如上图所示,原始transformer中的Norm放在多头注意力机制后面的,而VIT中是放在多头注意力机制的前头。VIT中没有pad符号。

        对于位置编码文章中提出有一维的位置编码,二维的位置编码,相对位置编码,下图介绍的不同位置编码的对比。以上述九宫格图像为例,一维位置编码是[1,2,3,4........];二维位置编码是[1,1],[1,2],[1,3],[2,1]...... ;相对位置编码像个patch的距离既可以用绝对位置表示,也可以用相对位置表示,一个patch相对于已知位置的patch。

       如下图所示,将N个patch,x1p.......进行编码,然后加上Xclass,再加上位置编码变成Z0 ,然后就是循环L次进行多头注意力机制MSA,前馈神经网络MLP操作,并且在操作之前都要经过LN,然后惊醒残差链接。最后做分类任务。

        当我们使用不同大小的数据集的时候(ImageNet->ImageNet-21k->JFT-300M),模型效果的不同表现,在较小的数据集ImageNet下VIT的效果要比BIT效果要差,原因在于VIT没有先验知识,没有归纳偏置。ImageNet-21k数据集下VIT的效果与BiT的效果差不多了。在更大的数据集下JFT-300M,VIT的效果全面超过BiT的效果。

 

二 源码简介:

       下图的代码是VIT源码中patch_embed.py文件中代码,这一部分就是将一个patch转化成固定维度的embeding做输入,将patch拉平成一维向量,并映射到transformer规定的emdebing size的纬度。Rearrange来自einops库的函数作用就是改变张量的形状。

        elif linear_patch:
            patch_dim = channels * patch_height * patch_width
            #先将图片拉平,然后映射到规定的大小
            self.projection = nn.Sequential(
                Rearrange(
                    'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                    p1=patch_height,
                    p2=patch_width,
                ),
                #Linear全连接层,patch_dim映射到embeding_dim
                nn.Linear(patch_dim, embedding_dim),
            )

       下图的代码是VIT源码中patch_embed.py文件中代码,其中包括构建cls_head和pos_embed。

        if linear_patch or conv_patch:
            self.grid_size = (
                image_height // patch_height,
                image_width // patch_width,
            )
            num_patches = self.grid_size[0] * self.grid_size[1]
            #nn.Parameter()来将这个随机初始化的Tensor注册为可学习的参数Parameter
            if cls_head:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
                num_patches += 1

            # positional embedding
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embedding_dim)
            )
            self.pos_drop = nn.Dropout(p=position_embedding_dropout)

      下图的代码是VIT源码中transformer.py文件中代码,对应原理中描述的VIT整体结构,有注意力机制,前馈神经网络构成,并且在多头注意力机制和前馈神经网络之前都进行layer normalization操作。

#nn.Sequential内部实现了forward函数,因此可以不用写forward函数。而nn.ModuleList则没有实现内部forward函数。nn.ModuleList 这些模块之间并没有什么先后顺序可言
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        #在多头注意力机制之前进行layer normalization
                        PreNorm(
                            dim,
                            Attention(
                                dim,
                                num_heads=heads,
                                qkv_bias=qkv_bias,
                                attn_drop=attn_dropout,
                                proj_drop=dropout,
                            ),
                        ),
                        # 在前馈神经网络之前进行layer normalization
                        PreNorm(
                            dim,
                            FeedForward(dim, mlp_dim, dropout_rate=dropout,),
                        )
                        if not revised
                        else FeedForward(
                            dim, mlp_dim, dropout_rate=dropout, revised=True,
                        ),
                    ]
                )
            )

     下图的代码是VIT源码中modules.py文件中代码,描述了多头注意力机制的过程。

 def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

论文地址:https://arxiv.org/pdf/2010.11929.pdf

源码:https://github.com/gupta-abhay/pytorch-vit

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值