3、Vision Transformer (ViT): 开启视觉识别的新纪元

文章介绍了Transformer在计算机视觉领域的创新应用,特别是ViT(VisionTransformer),它通过全局注意力机制、分块序列化处理和预训练-微调策略,改进了图像识别性能。文章详细描述了ViT的网络架构、实现步骤和一个简易代码示例。
摘要由CSDN通过智能技术生成

目录

一、论文题目

二、背景与动机

三、卖点与创新

四、具体实现细节

网络架构

具体实现步骤

简易代码 

五、一些资料


一、论文题目

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scaleicon-default.png?t=N7T8https://arxiv.org/abs/2010.11929

二、背景与动机

        在深度学习领域,卷积神经网络(CNN)长期以来一直占据着计算机视觉任务的主导地位,尤其在图像分类、目标检测等方面取得了显著成果。然而,CNN的设计基于局部感受野和权重共享等特性,对于全局信息的捕获存在一定的局限性。近年来,随着Transformer架构在自然语言处理(NLP)领域的巨大成功,研究者们开始思考如何将这种强大的序列建模能力引入到计算机视觉中。

三、卖点与创新

        Vision Transformer的主要创新点在于将Transformer架构直接应用于图像。

  1. 全局注意力机制:ViT的最大创新之处在于其采用了Transformer中的自注意力机制,摒弃了卷积层,能够对输入图像进行全局分析,这使得模型可以更好地理解和学习图像中的长距离依赖关系,从而提升模型性能。

  2. 分块序列化处理:ViT将整个图像分割成固定大小的 patches,并将这些patches线性嵌入后形成一个序列输入到Transformer中。这种“图像变序列”的处理方式,使得ViT可以直接利用NLP领域的成熟技术,在图像识别上表现出强大的适应性和灵活性。

  3. 预训练与微调:ViT借鉴了BERT等模型的预训练-微调策略,通过大规模无标注数据集进行预训练,然后在特定任务上进行微调,这一策略大大提升了模型的泛化能力和迁移学习效果。

四、具体实现细节

网络架构

具体实现步骤

  1. 输入图片shape为[224, 224, 3] [224,224,3]
  2. 通过Patch Embedding,即一个卷积 + Flatten(),生成shape为 [196,768] 的tokens
  3. 进行Class Embedding,即torch.concat(tokens, cls_token)
    cls_token为可训练参数,tokens的shape变化为: [196,768]→[197,768]
  4. Position Embedding,加上位置编码tokens,即tokens = tokens + pos_tokens
    pos_tokens为可训练参数,tokens的shape变化为: [197,768]→[197,768]
  5. 通过Dropout层
  6. 经过 若干层Transformer Encoder
  7. 经过LayerNorm层,shape为 [197,768]
  8. 提取Class Token所对应的输出
    这里的实现为切片,对 [197,768] 进行切片,只需要提取出Class Token对应的输出( [1,768] )即可
  9. 通过MLP Head得到最终的输出
    在自己训练集上使用,Pre-Logits就不要了,MLP Head就是nn.Linear
  10. 通过Softmax得到概率输出

简易代码 

import torch
from torch import nn
from torch.nn import functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.heads = heads
        self.head_dim = embed_dim // heads
        
        assert (
            self.head_dim * heads == embed_dim
        ), "embed_dim must be divisible by heads"
        
        # 分割嵌入向量到多个头部,每个头部有一个独立的线性层
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # 输出线性层将多头输出再次结合起来
        self.fc_out = nn.Linear(heads * self.head_dim, embed_dim)
    
    def forward(self, values, keys, queries, mask):
        N = queries.shape[0] # 批量大小
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # 分割输入以适应多头注意力机制
        # 维度从(N, seq_length, embed_dim)变为(N, seq_length, heads, head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # 将嵌入向量通过线性层并计算自注意力矩阵
        # 注意力矩阵的维度为(N, heads, query_len, key_len)
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # 使用“einsum”进行自注意力机制的运算
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_dim ** (1 / 2)), dim=3)

        # 将注意力矩阵与值向量相乘并重新排列回原来的维度
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # 最终通过一个线性层得到输出
        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = SelfAttention(embed_dim, heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        # 前馈网络使用两个线性层
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, forward_expansion*embed_dim),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_dim, embed_dim)
        )
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        
        # 添加skip connection, 然后进行层归一化和dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class ViT(nn.Module):
    def __init__(self, img_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dropout=0.1, emb_dropout=0.1):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.dim = dim
        self.depth = depth
        
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2
        
        # 位置编码和CLS token初始化
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches+1, dim))
        self.patch_to_embedding = nn.Linear(self.patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        
        # 创建多个Transformer块
        self.transformer = nn.Sequential(
            *[TransformerBlock(dim, heads, dropout, mlp_dim) for _ in range(depth)]
        )
        
        self.to_cls_token = nn.Identity()
        
        # MLP头部用于分类任务
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, num_classes)
        )
        
    def forward(self, x, mask=None):
        N, C, H, W = x.shape  # 图像维度
        # 切割图像成patches并平铺
        x = x.view(N, C, H // self.patch_size, self.patch_size, W // self.patch_size, self.patch_size)
        x = x.permute(0, 2, 4, 3, 5, 1)  # N, n_patches, patch_size, patch_size, C
        x = x.flatten(2)  # N, n_patches, patch_dim
        x = x.flatten(1)  # N, n_patches * patch_dim
        
        patches = self.patch_to_embedding(x)  # N, num_patches, dim
        cls_tokens = self.cls_token.expand(N, -1, -1)  # N, 1, dim
        x = torch.cat((cls_tokens, patches), dim=1)  # N, num_patches+1, dim
        x += self.pos_embedding  # 加上位置编码
        x = self.dropout(x)

        x = self.transformer(x, mask)  # N, num_patches+1, dim
        
        x = self.to_cls_token(x[:, 0])  # N, dim
        return self.mlp_head(x)  # N, num_classes

# 设置超参数 (根据你的数据和资源进行调整)
img_size = 256  # 图像大小
patch_size = 32  # patch大小
num_classes = 10  # 分类类别数量
channels = 3  # 图像通道数
dim = 1024  # 嵌入向量的维度
depth = 6  # Transformer块的数量
heads = 8  # 注意力头的数量
mlp_dim = 2048  # 前馈网络的隐藏层维度
dropout = 0.1  # dropout比率
emb_dropout = 0.1  # embedding dropout比率

# 实例化模型
vit = ViT(
    img_size=img_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    channels=channels,
    dropout=dropout,
    emb_dropout=emb_dropout
)

# 创建一个假的图像张量 (批大小, 通道数, 高度, 宽度)
img = torch.randn(1, channels, img_size, img_size)

# 获得预测结果
preds = vit(img)  # 维度为 (1, num_classes)
print(preds)

五、一些资料

ViT (Visual Transformer) - 知乎Acknowledge论文名称: An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale原论文对应源码: https://github.com/google-research/vision_transformerPyTorch实现代码: pytorch_classi…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/464920124

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值