Paper Reading Notes: Vision Transformer (ViT)

1. Vision Transformer

Dosovitskiy, Alexey, et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” arXiv preprint arXiv:2010.11929 (2020).

This is the paper that established that Transformers can compete with, and at scale beat, traditional convolutional networks in vision. After their great success in NLP, Transformers also achieve excellent results on vision tasks. The authors discard convolution entirely: the image is split into patches, each patch is encoded, and the resulting sequence is fed into a Transformer just like a sequence of text. On medium-sized datasets the results are not as good as those of convolutional networks, but on large-scale datasets ViT already surpasses them.

[Figure: ViT model overview from the paper — the image is split into patches, linearly embedded, combined with position embeddings and a class token, and fed to a Transformer encoder]
Suppose an image has size $H \times W \times C$ and each patch has size $P \times P$. After splitting, the image can be represented as an $N \times (P^2 \times C)$ matrix, where $N = (H \times W) / P^2$, so the initial encoding length of a single patch is $P^2 \times C$. After a linear projection and the addition of position embeddings, the sequence can be trained just like text. In addition, as shown in the figure, nine patches are fed into the network, but it is hard to decide which patch's output should be used for the final image classification, so an extra cls_token is fed into the network for classification. Its dimension matches that of a patch embedding, and it can be thought of as an artificially added patch used purely for the final classification.
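As a quick shape walk-through (my own illustration with einops, not code from the paper; the $224 \times 224 \times 3$ input and $P = 16$ are just example values), the snippet below patchifies a dummy image into $N = 196$ patches of length $16^2 \times 3 = 768$ and then prepends a cls token:

import torch
from einops import rearrange, repeat

img = torch.randn(1, 3, 224, 224)        # dummy batch: 1 RGB image of size 224 x 224
P = 16                                   # patch size
dim = P * P * 3                          # P^2 * C = 768

# Cut the image into non-overlapping P x P patches and flatten each one.
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P, p2=P)
print(patches.shape)                     # torch.Size([1, 196, 768]); N = 224*224 / 16^2 = 196

# Prepend a cls token (randomly initialised here; learnable in the real model).
cls_token = torch.randn(1, 1, dim)
cls_tokens = repeat(cls_token, '1 1 d -> b 1 d', b=img.shape[0])
x = torch.cat((cls_tokens, patches), dim=1)
print(x.shape)                           # torch.Size([1, 197, 768])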

This is not the first time attention has been used in image processing: the SE (squeeze-and-excitation) block is also a form of attention, but it acts on the channel dimension, whereas ViT's attention is global, so every patch can attend to every other patch. In fact, when the patch size is 1x1, the effect becomes quite similar to an SE block.
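For contrast, here is a minimal sketch of an SE block (my own illustration, not code from either paper): it squeezes each channel to a single value by global average pooling and then gates the channels, so the attention weights live on the channel dimension rather than across spatial positions.

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    # Squeeze-and-Excitation: channel-wise attention via global pooling + gating.

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)              # squeeze: one scalar per channel
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()                                 # excitation: per-channel weights in (0, 1)
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w                                     # re-weight channels; spatial layout untouched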

[Table: model configuration settings used in the experiments (ViT-Base / ViT-Large / ViT-Huge)]
The configuration settings used in the experiments are shown above; note how large the parameter counts are.

2. Code

import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torchvision
from torch.utils import data
import matplotlib.pyplot as plt
import copy
import math

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class PreNorm(nn.Module):

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(normalized_shape=dim)
        self.fn = fn

    def forward(self, x):
        x = self.norm(x)
        x = self.fn(x)
        return x

class FeedForward(nn.Module):

    def __init__(self, dim, hidden_dim, dropout=0.):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features=dim, out_features=hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_dim, out_features=dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.net(x)
        return x

class Attention(nn.Module):

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super(Attention, self).__init__()
        inner_dim = heads * dim_head
        project_out = not(heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** (-0.5)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim*3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: [batch_size, num_tokens, dim]
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)  # q, k, v all have shape [batch_size, num_heads, num_tokens, dim_head]
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                    PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
                ])
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):

    def __init__(self,image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 1, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super(ViT, self).__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )  # h * w equals the number of patches

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # img [batch_size, c, H, W]
        x = self.to_patch_embedding(img)  # [batch_size, num_patch, dim]
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)  # [batch_size, 1, dim]
        x = torch.cat((cls_tokens, x), dim=1)  # [batch_size, 1 + num_patch, dim]
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)

net = ViT(image_size=(224, 224), patch_size=(32, 32), num_classes=10, dim=768, depth=12, heads=12, mlp_dim=3072)
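A quick smoke test (my own addition, not part of the original post). Note that channels defaults to 1 in this implementation, so the dummy input is single-channel:

img = torch.randn(4, 1, 224, 224)   # batch of 4 single-channel 224 x 224 images
logits = net(img)                   # forward pass through the ViT instantiated above
print(logits.shape)                 # torch.Size([4, 10]) -- one logit per class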