详解VIT(Vision Transformer)模型原理, 代码级讲解

本文详细介绍了Google在ICLR上发布的VIT模型,它是首个在计算机视觉领域超越CNN和RNN的Transformer模型。文章重点阐述了VIT的结构,包括图像特征嵌入、Transformer编码器(含多头注意力机制)、MLP分类模块,以及模型的亮点和整体架构。
摘要由CSDN通过智能技术生成

一、模型简介

1. 论文地位:VIT模型(Vision Transformer),这是一篇Google于2021年发表在计算机视觉顶级会议ICLR上的一篇文章。它首次将Transformer这种发源于NLP领域的模型引入到了CV领域,并在ImageNet数据集上击败了当时最先进的CNN网络。这是一个标志性的网络,代表transformer击败了CNN和RNN,同时在CV领域和NLP领域达到了统治地位,此后基本在ImageNet排行榜上都是基于transformer架构的模型了。

2. 论文下载链接AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE  ;

3. 推荐复现的代码仓库,可以star我这个GitHub开源项目,对每行代码有详尽的注释:VIT模型详解

二、模型亮点及整体架构介绍

1. 整体架构:VIT首次将transformer模型运用到CV领域并且取得了不错的分类效果,模型原理图如图1所示。该图表示了VIT模型的整体架构,可以看出VIT只用了transformer模型的编码器部分,并未涉及解码器。其实VIT模型不难理解,只需要将其拆成三个部分(1.图像特征嵌入模块;2.transformer编码器模块;3.MLP分类模块)就可以很容易捋顺它的结构。

        1.1.图像特征嵌入模块:标准的VIT模型对图像的输入尺寸有要求,必须为224*224.图像输入之后首先是需要进行patch分块,一般设置patch的尺寸为16*16,那么一共能生成(224/16)*(224/16)=196个patch块。这部分内容在代码中如何实现呢?其实很简单,就是用一个卷积层就可以实现,其卷积核大小为patch size=16, 步长为patch size=16.

具体的代码如下所示,每行代码均有详细注释,展示了图像分块和特征嵌入的完整过程,嵌入之后的特征维度是[196,768],之后我们还需要加上位置编码和类别token,前者使用直接相加的方法,后者使用concat的方法,所以加上类别token后,特征的维度变化为[197,768]:        

class PatchEmbed(nn.Module):  # 继承nn.Module
    """
    所有注释均采用VIT-base进行说明
    图像嵌入模块类
    2D Image to Patch Embedding
    """
    # 初始化函数,设置默认参数
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()   # 继承父类的初始化方法
        # 输入图像的size为224*224
        img_size = (img_size, img_size)
        # patch_size为16*16
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # 滑动窗口的大小为14*14, 224/16=14
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        # 图像切分的patch的总数为14*14=196
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # 使用一个卷积层来实现图像嵌入,输入维度为[BatchSize,3,224,224],输出维度为[BatchSize,768,14,14],
        # 计算公式 size= (224-16+2*0)/16 + 1= 14
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        # 如果norm_layer为True,则使用,否则忽略
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        # 获取BatchSize,Channels,Height,Width
        B, C, H, W = x.shape  # 输入批量的图像
        # VIT模型对图像的输入尺寸有严格要求,故需要检查是否为224*224,如果不是则报错提示
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        # 将卷积嵌入后的图像特征图[BatchSize, 768, 14,14],从第3维度开始展平,得到BatchSize*768*196;
        # 然后转置1,2两个维度,得到BatchSize*196*768
        x = self.proj(x).flatten(2).transpose(1, 2)
        # norm_layer层规范化
        x = self.norm(x)
        return x

        1.2. Transformer Encoder:主要由LayerNorm层,多头注意力机制,MLP模块,残差连接这5个知识点模块构成组成。这是整个VIT模型最重要也是最需要花精力理解的地方,要搞清楚transformer的编码器部分,首先需要搞清楚多头注意力机制。

        由于注意力机制这块内容较多,引用博文供读者学习:狗都能看懂的self attention讲解 ;

 详解self attention以及 multi head self attention的原理 。后续将补上对于注意力机制这部分内容自己的理解,其实理解了注意力机制也就理解了transformer架构的基本原理。两种注意力机制的原理如下图所示。

        接下来我们按顺序往前讲解,LayerNorm层比较简单,直接调用pytorch中的nn.LayerNorm就可以实现。接下来看多头注意力机制的实现,如下所示。首先通过使用一个全连接层生成q,k,v的初始值,然后使用reshape和维度调换来进行调整,最后使用切片操作分别获得单独的Q,K,V,接下来就是transformer原始文章里提出的注意力机制的公式的实现了,公式如下:

class Attention(nn.Module):
    # 经过注意力层的特征的输入和输出维度相同
    def __init__(self,
                 dim,   # 输入token的维度
                 num_heads=8,  # multiHead中 head的个数
                 qkv_bias=False,  # 决定生成Q,K,V时是否使用偏置
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        # 设置多头注意力的注意力头的数目
        self.num_heads = num_heads
        # 针对每个head进行均分,它的Q,K,V对应的维度;
        head_dim = dim // num_heads
        # 放缩Q*(K的转置),就是根号head_dim(就是d(k))分之一,及和原论文保持一致
        self.scale = qk_scale or head_dim ** -0.5
        # 通过一个全连接层生成Q,K,V
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # 定义dropout层
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        # 通过一个全连接层实现将每一个头得到的注意力进行拼接
        self.proj = nn.Linear(dim, dim)
        # 使用dropout层
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim],其中,num_patches+1中的加1是为了保留class_token的位置
        B, N, C = x.shape   # [batch_size, 197, 768]

        # 生成Q,K,V;qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head],
        # 调换维度位置,permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # 将Q,K,V分离出来,切片后的Q,K,V的形状[batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # 这个的维度相同,均为[BatchSize, 8, 197, 768]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        # 即文章里的那个根据q,k, v 计算注意力的公式的Q*K的转置再除以根号dk
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # 对得到的注意力结果的每一行进行softmax处理
        attn = attn.softmax(dim=-1)
        # 添加dropout层
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)    # 乘V,通过reshape处理将每个head对应的结果进行拼接处理
        # 使用全连接层进行映射,维度不变,输入[BatchSize, 197, 768], 输出也相同
        x = self.proj(x)
        # 添加dropout层
        x = self.proj_drop(x)
        return x

        然后是MLP模块,也就是transformer原文中的前馈网络(feed forward),这一部分其实比较简单,没什么可讲的,就是两个全连接层加上dropout层实现:

class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # 注意当or两边存在None值时,输出是不为None那个
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # MLP模块中的第一个全连接层,输入维度in_features, 输出维度hidden_features
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        # MLP模块中的第二个全连接层,输入维度hidden_features, 输出维度out_features
        x = self.fc2(x)
        x = self.drop(x)
        return x

        残差连接模块,这个比较简单,代码实现如下:

# 相当于resnet中的shortcut,原论文中图1的结构中的+和箭头就是这个意思。因为输入和输出的维度完全相同,所以二者可以相加
# 将输入x与经过layer norm和多头注意力处理后的值进行残差相加
x = x + self.drop_path(self.attn(self.norm1(x)))
# 将输入x与经过layer norm和MLP处理后的值进行残差相加
x = x + self.drop_path(self.mlp(self.norm2(x)))

         好了,接下来是将这些不同功能的模块进行包装,代码和注释如下:

class Block(nn.Module):
    # 集成了Transformer Encoder的所有功能
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # 定义全连接层,输入维度为embedding dim=768,隐藏层为embedding dim*4=3072,输出层为in_features=768
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        # 相当于resnet中的shortcut,原论文中图1的结构中的+和箭头就是这个意思。因为输入和输出的维度完全相同,所以二者可以相加
        # 将输入x与经过layer norm和多头注意力处理后的值进行残差相加
        x = x + self.drop_path(self.attn(self.norm1(x)))
        # 将输入x与经过layer norm和MLP处理后的值进行残差相加
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

         其实讲到这里,基本已经理解了VIT模型知识体系的三分之二。接下来就是最后的MLP分类模块了,这一块比较简单,甚至可以只用一层全连接层来解决。之前没有了解过VIT的小伙伴,这里需要提示一下,我们输入到MLP类别分类器中的特征只有类别token。经过N层transformer编码器处理后的特征的维度与输入前相同,均为[197,768],我们只使用列表切片的方式提取出类别token,维度为[1,768].进行下一步的类别分类。有小伙伴可能不理解,那不是其它的特征没有用到吗?浪费了是不是。其实不是,多头注意力机制可以让不同位置的特征进行全面交互,这里输出的类别token和之前输入的类别token早已发生了巨变,这种变化是由其它特征影响的。

         最后提供一下,transformer模型的整体架构代码:

class VisionTransformer(nn.Module):
    # 集成VIT模型架构
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        # 创建可学习的位置token,形状为 (1, num_patches + self.num_tokens, embed_dim),初始值全为0
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        # nn.Sequential作为容器用来包装层,通过列表循环构建了depth个block
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer,即pre_logits层。这个if语句是问,如果需要pre_logits层并且不进行模型蒸馏则...
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            # 跳过pre_logits层
            self.pre_logits = nn.Identity()

        # Classifier head(s),使用一个全连接层进行分类结果输出
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init,对位置编码进行初始化
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)
        # 对类别编码进行初始化
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim],即[B, 196, 768]
        x = self.patch_embed(x)
        # 定义类别token,维度为[1,1,768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        # 注意dist_token是用于模型蒸馏的,此处设置为None
        if self.dist_token is None:
            # concat上类别token
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        # 添加位置编码
        x = self.pos_drop(x + self.pos_embed)
        # 将特征输入transformer编码器
        x = self.blocks(x)
        # LayerNorm层
        x = self.norm(x)
        # 注意dist_token是用于模型蒸馏的,此处设置为None
        if self.dist_token is None:
            # 返回类别token
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        # 进行forward_features之后获得训练后的类别token
        x = self.forward_features(x)
        # head_dist等于None,直接执行else
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            # 使用全连接层输出分类结果
            x = self.head(x)
        return x

待更新!!!

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Trouville01

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值