《Vision Transformer (ViT)》论文精度,并解析ViT模型结构以及代码实现

《AN IMAGE IS WORTH 16X16 WORDS:

TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》

论文共有22页,表格和图像很多,网络模型结构解释的很清楚,并且用四个公式展示了模型的计算过程;本文章对其进行精度,并对源码进行剖析,希望读者可以耐心读下去。

论文地址:https://arxiv.org/abs/2010.11929

源码地址(pytorch):https://github.com/lucidrains/vit-pytorch


目录

一、引言

二、结论

三、ViT模型结构

四、代码构建ViT模型(注释基本都在代码中)

(1)标准的Transformer编码器

(2)ViT模型

五、论文中的图像和表格分析

(1)ViT的模型变体

(2)分类精度结果对比

(3)数据集的大小对ViT的影响

(3)BiT、ViT、Hybrids模型集的比较

(4)ATTENTION DISTANCE

(5)学习率的影响

(6)位置嵌入的比较

(7)注意力权重可视化


一、引言

虽然Transformer架构已经成为自然语言处理任务的事实上的标准,但它在计算机视觉中的应用仍然有限。在视觉上,注意力要么与卷积网络结合使用,要么用于替换卷积网络的某些组成部分,同时保持其整体结构的到位。我们表明,这种对CNN的依赖是不必要的,直接应用于图像块序列的Transformer可以很好地执行图像分类任务。当对大量数据进行预训练并转移到多个中小规模的图像识别基准时(ImageNet、CIFAR-100、VTAB等),Vision Transformer ( ViT ) 相对于先进的卷积网络获得了优异的结果,同时训练所需的计算资源也大大减少。

二、结论

我们探索了Transformer在图像识别中的直接应用不同于以往在计算机视觉中使用自注意力的工作,我们除了初始的patches提取步骤外,并没有引入图像特定的诱导偏差。相反,我们将图像理解为一个patches序列,并通过标准的 Transformer 编码器对其进行处理,如同在NLP中使用一样。这种简单而又可扩展的策略在结合大规模数据集的预训练时表现出惊人的效果。因此,Vision Transformer 在许多图像分类数据集上匹配或超过了艺术状态,同时相对于训练前来说相对便宜。

虽然这些初步成果令人鼓舞,但仍存在许多挑战。一是将ViT应用到其他计算机视觉任务中,如检测和分割。我们的结果,再加上 《End-to-end object detection with transformers》(2020) 的结果,表明这一方法的前景。另一个挑战是继续探索自监督的预训练方法。我们的初始实验从自监督预训练方面表现出了改进,但自监督与大规模监督预训练相比仍有较大差距。最后,ViT 的进一步缩放很可能导致性能的提高。

三、ViT模型结构

原文:我们将图像分割成固定大小的块,线性嵌入其中的每个块,加入位置嵌入,并将得到的向量序列反馈给标准的Transformer编码器。为了执行分类,我们采用在序列中添加额外可学习的“classification token”的标准方法。

结合原文,把模型的结构解析成以下三个步骤:

(1)把大小为256*256的图像切割成16*16的patches,每一个patch的像素大小是14*14

 ​​​​​​(2)对一张图像进行切割、位置嵌入、分类嵌入的操作,返回一个大小为197*768的矩阵,准备送入到标准的Transformer编码器中。

 (3)本文章举例处理了一张图像,即batch_size=1。[1, 197, 768]经标准的Transformer编码器输出[1, 197, 768],然后经过分类头,输出分类结果。

 以上三个步骤的计算公式,可以由下面的四个公式概括:

如下图所示,右侧是ViT模型的结构图,左边是对各个公式的参数解释,其对应右侧的各个模型块 

 四、代码构建ViT模型(注释基本都在代码中)

构建ViT模型的结构前,需要搭建一个标准的Transformer编码器结构。

一个标准的Transformer编码器结构包括:多头注意力机制、层规范化和残差连接和多层感知机(简单的神经网络)。

******如果对Transformer的结构不是很熟悉,可以参考我的 “变形金刚 Transformer” 专栏******

(1)标准的Transformer编码器

层规范化:

class PreNorm(nn.Module):
    """层规范化"""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

多层感知机:

class FeedForward(nn.Module):
    """前馈神经网络(MLP)"""
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

多头注意力机制:

class Attention(nn.Module):
    """多头注意力机制"""
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

Transformer编码器:

class Transformer(nn.Module):
    """标准的Transformer编码器"""
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        # 实现残差网络的相加
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

(2)ViT模型

ViT模型结构可以分成三个部分:

  • 首先对图像进行预处理,分成固定大小的patches后,平坦化并进行位置和分类嵌入;
  • 然后经过标准的Transformer编码器输出自注意权重矩阵;
  • 最后经过LN和MLP进行图像分类。
class ViT(nn.Module):
    def __init__(self, *,
                 image_size,
                 patch_size,
                 num_classes,
                 dim,
                 depth,
                 heads,
                 mlp_dim,
                 pool='cls',
                 channels=3,
                 dim_head=64,
                 dropout=0.,
                 emb_dropout=0.
                 ):
        super().__init__()
        # 图像尺寸
        image_height, image_width = pair(image_size)
        # patch尺寸
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0,  'Image dimensions must be divisible by the patch size.'
        # patches数量
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 一个patch的维度
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # 处理原始图像
        # Rearrange:按照字符串的含义对目标进行重新排列的操作
        # img = torch.randn(1, 3, 256, 256)  '1 3 (256 32) (256 32) -> 1 (256 256) (32 32 3)'
        # 分成32*32个patch,并平铺成1024个patch,每一个patch的大小为8*8
        # 然后经过一个全连接层(3072, 1024)输出处理后的图像(三维:32*32*3=3072)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )
        # Sequential(
        #   (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
        #   (1): Linear(in_features=3072, out_features=1024, bias=True)
        # )
        # print(self.to_patch_embedding)

        # 位置嵌入(patches+cls的位置信息)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # class token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # 定义标准Transformer模型结构
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool

        # 不区分参数的占位符标识运算符。
        # identity模块不改变输入,直接return input
        # 一种编码技巧吧,比如我们要加深网络,有些层是不改变输入数据的维度的,这时就可以使用此函数
        # 这个网络层的设计是仅用于占位的,即不干活,只是有这么一个层,放到残差网络里就是在跳过连接的地方用这个层,显得没有那么空虚!
        self.to_latent = nn.Identity()

        # 定义MLP分类头的模型结构
        # 首先经过LN,然后经过一个全连接层(1024, 1000),输出分类结果
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # img = torch.randn(1, 3, 256, 256)
        # 即往网络中送入一张256*256的三维图像(batch_size=1)
        x = self.to_patch_embedding(img)
        # 平铺后,共有8*8个patch,其大小为32*32
        print('平铺:', x.size())  # torch.Size([1, 64, 1024])
        b, n, _ = x.shape
        # print(b, n)  # 1 64

        # class tokens
        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b=b)
        # print(cls_tokens.size())  # torch.Size([1, 1, 1024])

        x = torch.cat((cls_tokens, x), dim=1)  # 沿着dim=1方向对cls_token和x进行拼接
        x += self.pos_embedding[:, :(n + 1)]  # 拼接后的x与嵌入未知信息
        x = self.dropout(x)
        print(x.size())  # torch.Size([1, 65, 1024])

        # 把预处理的图像送入Transformer模型中
        x = self.transformer(x)
        print(x.size())  # torch.Size([1, 65, 1024])

        # 如果pool=='mean',返回dim=1方向上的元素平均值
        # 否则,直接返回dim=0方向上的第一行的所有元素,即class tokens
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        print(x.size())  # torch.Size([1, 1024])

        # 不区分参数的占位符标识运算符。
        x = self.to_latent(x)

        # 返回分类头的分类结果
        return self.mlp_head(x)

测试代码:输入指定参数,调用ViT模型,输出分类矩阵:

输入参数含义:

  • image_size=256,       # 图像尺寸
  • patch_size=32,          # patches大小
  • num_classes=1000,  # 标签数量
  • dim=1024,                 # patch的维度
  • depth=6,                    # 模型深度(编码器的数量)
  • heads=16,                 # 注意力头的数量
  • mlp_dim=2048,         # Transformer中MLP的输出维度
  • dropout=0.1,              # Transformer中MLP的舍弃率
  • emb_dropout=0.1      # 嵌入舍弃率
import torch
from vit_pytorch import ViT


def test():
    v = ViT(
        image_size=256,  # 图像尺寸
        patch_size=32,   # patches数量
        num_classes=1000,# 标签数量
        dim=1024,        # patch的维度
        depth=6,         # 模型深度(编码器的数量)
        heads=16,        # 注意力头的数量
        mlp_dim=2048,    # Transformer中MLP的输出维度
        dropout=0.1,     # Transformer中MLP的舍弃率
        emb_dropout=0.1  # 嵌入舍弃率
    )

    img = torch.randn(1, 3, 256, 256)
    preds = v(img)

    # 如果preds.shape != (1, 1000),自行中断程序,并报错:'correct logits outputted'
    assert preds.shape == (1, 1000), 'correct logits outputted'

    return preds.shape


VIT_result = test()
print(VIT_result)  # torch.Size([1, 1000])

五、论文中的图像和表格分析

(1)ViT的模型变体

下面的表格表示的是ViT的模型变体,分别是ViT-B、ViT-L、ViT-H。

我们基于用于BERT《BERT: Pre-training of deepbidirectional transformers for language understanding》的ViT配置,如下表所示。“Base”和“Large”模型直接采用BERT,我们加入了较大的“Huge”模型。

下面我们使用简要说明来表示模型大小和输入patch大小:例如,ViT-L / 16表示输入patch大小为16×16的 “Large” 变体。注意,Transformer的序列长度与patch大小的平方成反比,因此具有较小patch大小的模型计算代价更高。

(2)分类精度结果对比

我们发现,大规模的训练会克服归纳性偏差。我们的视觉变形金刚( ViT )在足够规模的预训练后转移到数据量较少的任务时取得了很好的效果。在公共ImageNet-21k数据集或内部 JFT-300M 数据集上进行预训练时,ViT接近器在多个图像识别基准上击败了艺术状态。特别是最佳模型在ImageNet上达到了88.55%的准确率,在ImageNet-ReaL上达到了90.72%,在CIFAR-100上达到了94.55%。

在JFT - 300M上预训练的较小的 ViT - L / 16 模型在所有任务上都优于 BiT - L ( 在同一数据集上进行预训练 )模型,同时训练所需的计算资源大大减少。更大的模型 ViT-H / 14 进一步提高了性能,特别是在更具挑战性的数据集- ImageNet、CIFAR-100和VTAB套件上。有趣的是,该模型对预训练的计算量仍然大大低于现有的技术状态。但是,我们注意到,预训练效率不仅可能受到体系结构选择的影响,还可能受到其他参数的影响,如训练计划、优化器、权值衰减等。最后,在公开的ImageNet-21k数据集上预训练的 ViT-L / 16 模型在大多数数据集上也表现良好,同时占用了很少的资源进行预训练:它需要使用8个核心的标准云TPUv3训练大约30天。

下面的表格表示微调不同ResNet中Adam、SGD的消融实验:

(3)数据集的大小对ViT的影响

  • 左侧的图像表示:在小数据集上进行预训练时,大型ViT模型的表现要比BiTResNets 差,但在较大数据集上进行预训练时,ViT会发光。随着数据集的增长,较大的ViT变体会取代较小的ViT变体。
  • 右侧的图像表示:ResNets在较小的预训练数据集上表现更好,但ViT在较大的预训练数据集上表现更好。

(3)BiT、ViT、Hybrids模型集的比较

预训练的模型集:

  • BiT:7 ResNets, R50x1,R50x2 R101x1, R152x1, R152x2, pre-trained for 7 epochs, plus R152x2 and R200x3 pre-trainedfor 14 epochs;
  • ViT: 6 Vision Transformers, ViT-B/32, B/16, L/32, L/16, pre-trained for 7 epochs, plusL/16 and H/14 pre-trained for 14 epochs;
  • Hybrids:5 hybrids, R50+ViT-B/32, B/16, L/32, L/16 pre-trained for 7 epochs, plus R50+ViT-L/16 pre-trained for 14 epochs 

如下图所示:比较不同架构的性能与预训练计算:ViT,ResNets和混合Transformer的比较。在相同的计算开销下,ViT的性能一般优于ResNet。对于较小的模型尺寸,混合Transformer比纯Transformer有所改善,而对于较大的模型尺寸,纯Transformer比混合Transformer有所改善。

启发:这个结果有些令人惊讶,因为人们可能期望卷积局部特征处理能够在任何大小上辅助ViT。其次,ViT似乎没有在所尝试的范围内饱和,从而激励了未来的缩放努力(主要的缩放策略便是对模型的宽度(w)、深度(d)和分辨率(r)进行调整)。

(4)ATTENTION DISTANCE

为了了解ViT如何使用自注意来整合图像中的信息,我们分析了不同层次的注意权重所跨越的平均距离。这种 “注意力距离” 类似于CNN的感受野大小。平均注意距离在较低层的头之间是高度可变的,有的头会关注图像的大部分,有的头会关注查询位置处或附近的小区域。随着深度的增加,所有注意力头的注意距离增加。在网络的后半部分,大多数的heads通过tokens广泛参加。

如下图所示,按heads和网络深度划分参加区域的规模。通过平均查询像素与所有其他像素之间的距离,加权得到的注意力权重,对128幅样本图像计算了注意距离。每一个点显示了16个heads中的1个head在1层Transformer中的平均注意距离。图像宽度为224像素。

 (5)学习率的影响

分类头:这个设计继承了Transformer模型的文本,我们在整个主论文中使用它。最初尝试只使用图像块嵌入,全局平均池( GAP )它们,随后使用线性分类器(就像ResNet的最终特征映射一样)的性能很差。但是,我们发现性能差既不是由于额外的token,也不是GAP操作造成的。相反,他在表现上的差异完全由不同的学习率影响。

 (6)位置嵌入的比较

下表总结了本消融研究在 ViT-B / 16 模型上的结果。正如我们所看到的,虽然没有位置嵌入的模型和有位置嵌入的模型的性能有很大的差距,但是不同的位置信息编码方式之间几乎没有差别。我们推测,由于我们的Transformer编码器工作在 patch-level 的输入上,相对于像素级,在如何编码空间信息方面的差异并不重要。更确切地说,在 patch-level 输入中,空间维度远小于原始的像素级输入,如14 × 14而不是224 × 224,对于这些不同的位置编码策略,学习表示空间关系同样容易。

即便如此,网络学习到的位置嵌入相似度的具体模式取决于训练超参数。 

 (7)注意力权重可视化

认真阅读完毕,定会有所收获!

>>>如有疑问,欢迎评论区一起探讨!

  • 12
    点赞
  • 112
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
ViTVision Transformer)是一种用于计算机视觉任务的Transformer模型。它在处理图像数据时,将图像划分为一系列的图像块,然后将这些图像块转换为序列数据,并使用Transformer编码器对其进行处理。ViT利用了Transformer的自注意力机制,通过学习将图像块之间的关系建模,从而实现对图像的特征提取和表征学习。 ViT模型的核心思想是引入了位置嵌入(position embedding)来为序列数据引入位置信息。位置嵌入是Transformer模型中的一部分,它可以将每个序列元素与其在原始图像中的位置相关联。这样,模型就可以利用位置信息来捕捉图像中不同区域的上下文关系。关于Transformer位置嵌入的详细信息,可以参考中的《【机器学习】详解 Transformer_闻韶-CSDN博客_机器学习transformer》的解读。 另外,关于ViT的更多研究论文和应用实例,可以参考中的GitHub资源,该资源收集了一些关于Transformer计算机视觉结合的论文。同时,中的《机器学习》也提供了对Transformer编码器结构的详细解释,可以进一步了解Transformer模型的工作原理。 总结起来,ViT是一种通过将图像转换为序列数据,并利用Transformer模型进行特征提取和表征学习的方法。它利用位置嵌入来引入图像中不同区域的位置信息,并通过自注意力机制来建模图像块之间的关系。通过研究论文和资源,我们可以深入了解ViT模型的原理和应用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Flying Bulldog

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值