Understanding Vision Transformer (ViT) in One Article

Paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
Paper link: https://arxiv.org/abs/2010.11929

Preface

Transformers had already made a huge impact in NLP before 2020, and unsupervised pre-training techniques such as BERT pushed NLP to a new level. Inspired by this, ViT applies the Transformer to computer vision: pre-trained on the JFT dataset, it reaches 88.55% accuracy on ImageNet-1K. Today, Transformers achieve state-of-the-art results in CV, NLP, multimodal learning and beyond, and are well worth learning and applying in practice. In the posts ahead I will document my journey through large models in computer vision; discussion and feedback are very welcome.

Model Details

ViT pipeline in brief: ViT splits the input image into patches with a large-kernel convolution. These patches play the role of tokens in NLP: each is projected to a fixed-length vector and fed into the Transformer Encoder, and after L encoder layers the output features are obtained. Since this is a classification task, a special cls_token is added to the input tokens; like the tokens produced by patch embedding, it takes part in the q/k/v computation inside the Transformer Encoder and thus aggregates global context, and its output is finally passed through a fully connected layer to predict the class.
[Figure: ViT architecture overview]
As shown in the figure above, ViT consists of the following steps:

1. Patch embedding
An image has data layout [N, C, H, W]; here we use an input of shape [64, 3, 224, 224]. As shown in the code below, self.projection (a Conv2d with kernel=16 and stride=16) splits the 224x224 image into 16x16 patches, giving $(224/16)^2 = 196$ patches, and turns the input into a feature map of shape [64, 768, 14, 14]. Because the kernel and stride are large, the input first goes through self.adaptive_padding(x), which pads it when H or W is not divisible by the stride (for 224x224 and stride 16 no padding is needed). The feature map is then flattened and transposed (x = x.flatten(2).transpose(1, 2)) from [64, 768, 14, 14] to [64, 196, 768]. During patch embedding self.norm=None, so no normalization is applied at this stage. Taking ViT-B/16 as an example, each [16, 16, 3] patch is mapped by the convolution to a vector of length 768 (referred to as a token from here on; i.e. a [16, 16, 3] patch is embedded into a [1, 768] feature).

A class token self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) is introduced as the feature the model ultimately uses for recognition; it is concatenated with the patch tokens above to form the final input of shape [64, 197, 768]. At this point, ViT has turned image recognition into a sequence problem, just as in NLP.
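Before the actual mmcls PatchEmbed implementation below, here is a minimal plain-PyTorch sketch of this shape flow (the tensor names and print checks are illustrative, not taken from the mmcls code):

import torch
import torch.nn as nn

x = torch.randn(64, 3, 224, 224)                   # input batch [N, C, H, W]
projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)
cls_token = nn.Parameter(torch.zeros(1, 1, 768))   # learnable class token

feat = projection(x)                               # [64, 768, 14, 14]
tokens = feat.flatten(2).transpose(1, 2)           # [64, 196, 768]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # [64, 1, 768]
tokens = torch.cat((cls_tokens, tokens), dim=1)    # [64, 197, 768]
print(tokens.shape)                                # torch.Size([64, 197, 768])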

class PatchEmbed(BaseModule):
    def __init__(self,
                 in_channels=3,
                 embed_dims=768,
                 conv_type='Conv2d',
                 kernel_size=16,
                 stride=16,
                 padding='corner',
                 dilation=1,
                 bias=True,
                 norm_cfg=None,
                 input_size=None,
                 init_cfg=None):
        super(PatchEmbed, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adaptive_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adaptive_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # e.g. when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adaptive_padding:
                pad_h, pad_w = self.adaptive_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x):
        if self.adaptive_padding:
            x = self.adaptive_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size
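A quick usage sketch of the PatchEmbed module above (assuming the mmcv helpers it depends on, such as AdaptivePadding, build_conv_layer and to_2tuple, are importable):

import torch

patch_embed = PatchEmbed(in_channels=3, embed_dims=768, kernel_size=16, stride=16)
x = torch.randn(64, 3, 224, 224)
tokens, out_size = patch_embed(x)   # tokens: [64, 196, 768], out_size: (14, 14)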

2. Positional encoding
In step 1 (patch embedding), ViT embeds the [64, 3, 224, 224] input into tokens of shape [64, 197, 768]. Flattening the 2D image into a 1D token sequence discards spatial position information, so ViT adds a positional embedding to reintroduce it. self.pos_embed is a learnable parameter of shape [1, 197, 768], initialized from a truncated normal distribution with trunc_normal_.

self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_extra_tokens,
                        self.embed_dims))
trunc_normal_(self.pos_embed, std=0.02)

Note that ViT is pre-trained on large datasets, and in downstream tasks the number of tokens can differ: if the downstream image size differs from the pre-training image size, num_patches changes and the shape of self.pos_embed no longer matches. To solve this, ViT interpolates the self.pos_embed parameters so that the shapes match while still preserving most of the pre-trained positional information. Convolutions and the token-wise fully connected layers can handle different input resolutions as long as the channel dimension is unchanged, so interpolating self.pos_embed to match num_patches is all that is needed to transfer the pre-trained model to a downstream task for fine-tuning.

def resize_pos_embed(pos_embed,
                     src_shape,
                     dst_shape,
                     mode='bicubic',
                     num_extra_tokens=1):
    if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
        return pos_embed
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
    _, L, C = pos_embed.shape
    src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, \
        f"The length of `pos_embed` ({L}) doesn't match the expected " \
        f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \
        '`img_size` argument.'
    extra_tokens = pos_embed[:, :num_extra_tokens]

    src_weight = pos_embed[:, num_extra_tokens:]
    src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)

    dst_weight = F.interpolate(
        src_weight, size=dst_shape, align_corners=False, mode=mode)
    dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)

    return torch.cat((extra_tokens, dst_weight), dim=1)
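For example, to fine-tune a model pre-trained at 224x224 (a 14x14 patch grid) on 384x384 inputs (a 24x24 grid), the positional embedding would be resized roughly like this (shapes are illustrative):

import torch

pos_embed = torch.zeros(1, 14 * 14 + 1, 768)   # pre-trained positional embedding: 196 patches + cls token
new_pos_embed = resize_pos_embed(
    pos_embed, src_shape=(14, 14), dst_shape=(24, 24),
    mode='bicubic', num_extra_tokens=1)
print(new_pos_embed.shape)                     # torch.Size([1, 577, 768]), i.e. 24*24 + 1 tokens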

3. Transformer Encoder
[Figure: Transformer Encoder block structure (LN, Multi-Head Attention, MLP with residual connections)]

In step 2, the tokens are summed with the positional embedding to obtain the Embedded Patches that are fed into the Transformer Encoder. As the figure shows, the Transformer Encoder has three components: Norm, Multi-Head Attention, and MLP. Here Norm means layer normalization (LN), Multi-Head Attention is the multi-head form of self-attention, and the MLP is two fully connected layers.

[Figure: feature dimensions normalized by BN vs LN]

LN:
LN is a feature normalization method similar to BN, but, as shown in the figure above, LN normalizes over the [C, H, W] dimensions of each sample, i.e. mean and variance are computed per sample. Why use LN here? In NLP the token length is not fixed, so sequences in one batch can have different lengths. BN computes statistics per channel across the batch, which ignores the structure within a sentence; shorter sequences have to be zero-padded, so the batch statistics are noisy and not very meaningful. LN, in contrast, normalizes each sample independently of the batch, so the sequence length does not affect its statistics and the whole sentence is covered, keeping the tokens of one sequence consistent with each other. ViT keeps LN so that the original Transformer structure is unchanged.
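In ViT the LN is applied over the 768-dim embedding of each token. A minimal sketch (shapes are illustrative) showing that the statistics are computed per token, independent of the batch and of the sequence length:

import torch
import torch.nn as nn

tokens = torch.randn(64, 197, 768)    # [B, N, C] token sequence
ln = nn.LayerNorm(768)                # normalize over the last (embedding) dimension
out = ln(tokens)                      # each token normalized with its own mean/variance
print(out.shape)                      # torch.Size([64, 197, 768])
print(out.mean(dim=-1).abs().max())   # per-token mean is ~0 right after initialization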

Multi-Head Attention:
Self-attention by itself has relatively few parameters (just the three fully connected projections for q, k and v), which limits the model's expressive power. To increase capacity, ViT borrows the idea of channels from CNNs: the embedding is split across several self-attention heads that run in parallel, which enriches the representation without a significant increase in memory. This is Multi-Head Attention.

As the code below shows, x is the Embedded Patch feature of shape [64, 197, 768]. A fully connected layer (self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)) maps it to qkv of shape [64, 197, 2304]; embed_dims is multiplied by 3 because q, k and v are produced together. For multi-head attention, qkv is reshaped to (B, N, 3, self.num_heads, self.head_dims), where B is the batch size, N=197 is the number of tokens, self.num_heads=12 is the number of heads, and self.head_dims=64 is the per-head dimension. Since q, k and v come out of a single fully connected layer, they are then split apart: q, k, v = qkv[0], qkv[1], qkv[2].

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(QK^T / \sqrt{d_k})\,V$
First the row-wise inner products between Q and K are computed; to keep them from becoming too large they are divided by $\sqrt{d_k}$, and a softmax turns the result into the weights applied to V. Every token attends to every other token, so the Transformer has a global receptive field (global context) from the very first layer. This is the biggest difference from CNNs, which must stack many downsampling layers before the deep features gain a large receptive field.

class MultiheadAttention(BaseModule):
    """Multi-head Attention Module
    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 input_dims=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 qkv_bias=True,
                 qk_scale=None,
                 proj_bias=True,
                 v_shortcut=False,
                 init_cfg=None):
        super(MultiheadAttention, self).__init__(init_cfg=init_cfg)

        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut

        self.head_dims = embed_dims // num_heads
        self.scale = qk_scale or self.head_dims**-0.5

        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = DROPOUT_LAYERS.build(dropout_layer)

    def forward(self, x):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  self.head_dims).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
        x = self.proj(x)
        x = self.out_drop(self.proj_drop(x))

        if self.v_shortcut:
            x = v.squeeze(1) + x
        return x
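A hedged usage sketch of the MultiheadAttention module above (it assumes the mmcls infrastructure it imports, e.g. BaseModule and DROPOUT_LAYERS, is available):

import torch

attn = MultiheadAttention(embed_dims=768, num_heads=12)
x = torch.randn(64, 197, 768)       # Embedded Patches
out = attn(x)                       # attention computed across all 197 tokens
print(out.shape)                    # torch.Size([64, 197, 768])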

MLP:
The MLP expands the dimension and then projects it back: 197x768 is expanded to 197x3072 and then reduced back to 197x768. In mmcls the MLP is implemented as the FFN module.

class FFN(BaseModule):
    """Implements feed-forward networks (FFNs) with identity connection.
    """

    @deprecated_api_warning(
        {
            'dropout': 'ffn_drop',
            'add_residual': 'add_identity'
        },
        cls_name='FFN')
    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super(FFN, self).__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        The function adds ``identity`` (``x`` if ``identity`` is None) to the output tensor.
        """
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
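For ViT-B/16, feedforward_channels is 3072 (4x the 768 embedding dimension). A usage sketch, again assuming the mmcv/mmcls helpers above are importable (ViT uses GELU rather than the default ReLU):

import torch

ffn = FFN(embed_dims=768, feedforward_channels=3072, act_cfg=dict(type='GELU'))
x = torch.randn(64, 197, 768)
out = ffn(x)                        # 768 -> 3072 -> 768, with the residual added inside
print(out.shape)                    # torch.Size([64, 197, 768])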

The whole TransformerEncoderLayer is assembled from the parts above; both norms are LN, and dropout is applied in every layer. Because the Transformer has high capacity and overfits easily, ViT adds a good amount of dropout and droppath.

class TransformerEncoderLayer(BaseModule):
    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def init_weights(self):
        super(TransformerEncoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = self.ffn(self.norm2(x), identity=x)
        return x
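Since each block maps [64, 197, 768] to the same shape, stacking is straightforward. A sketch of stacking the ViT-B depth of 12 blocks (assuming the mmcls registry entries for LN, GELU and DropPath are available):

import torch
import torch.nn as nn

# ViT-B: depth 12, 12 heads, hidden dim 768, MLP dim 3072
layers = nn.Sequential(*[
    TransformerEncoderLayer(embed_dims=768, num_heads=12, feedforward_channels=3072)
    for _ in range(12)
])
x = torch.randn(64, 197, 768)
print(layers(x).shape)              # torch.Size([64, 197, 768])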

One block leaves the shape unchanged, still 197x768, so multiple blocks can be stacked. In the end, the output at the position of the special cls token is taken as the encoder's final output, i.e. the image representation. The ViT pipeline runs as follows:
1. self.patch_embed(x) embeds the image into tokens, and cls_token is concatenated onto x. The positional embedding $E_{pos}$ is resized to the current patch resolution and added to x (x = x + resize_pos_embed), giving the Transformer Encoder input $z_0$;
2. $z_{l-1}$ is normalized with LN, $\mathrm{LN}(z_{l-1})$, passed through multi-head attention and added to the residual, producing $z'_l$;
3. $z'_l$ is normalized with LN, passed through the two-layer FFN and added to its residual, producing $z_l$;
4. cls_token = x[:, 0]: the output at the cls position is taken as the encoder's final output, the image representation.
In equation form (Eq. 1–4 of the paper):
$z_0 = [x_{class};\, x_p^1 E;\, x_p^2 E;\, \dots;\, x_p^N E] + E_{pos}$
$z'_l = \mathrm{MSA}(\mathrm{LN}(z_{l-1})) + z_{l-1}, \quad l = 1, \dots, L$
$z_l = \mathrm{MLP}(\mathrm{LN}(z'_l)) + z'_l, \quad l = 1, \dots, L$
$y = \mathrm{LN}(z_L^0)$

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

            if i in self.out_indices:
                B, _, C = x.shape
                if self.with_cls_token:
                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)
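The backbone's forward above stops at the features; in mmcls the classification itself is done by a separate head module. A minimal, self-contained sketch of that last step (the stand-in tensor and linear head below are illustrative; the actual mmcls head may add an extra hidden layer):

import torch
import torch.nn as nn

# suppose the backbone returned [patch_token, cls_token] (output_cls_token=True)
cls_token = torch.randn(64, 768)       # stand-in for the backbone's cls-token output
head = nn.Linear(768, 1000)            # e.g. 1000 ImageNet-1K classes
logits = head(cls_token)               # [64, 1000] class scores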