【计算机视觉】ViT:代码逐行解读

一、代码

import torch
import torch.nn as nn
from einops import rearrange

from self_attention_cv import TransformerEncoder


class ViT(nn.Module):
    def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Args:
            img_dim: the spatial image size
            in_channels: number of img channels
            patch_dim: desired patch dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of heads
            dim_linear_block: inner dim of the transformer linear block
            dim_head: dim head in case you want to define it. defaults to dim/heads
            dropout: for pos emb and transformer
            transformer: in case you want to provide another transformer implementation
            classification: creates an extra CLS token
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
        self.p = patch_dim
        self.classification = classification
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, dim)

        self.emb_dropout = nn.Dropout(dropout)
        if self.classification:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

        if transformer is None:
            self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                  dim_head=self.dim_head,
                                                  dim_linear_block=dim_linear_block,
                                                  dropout=dropout)
        else:
            self.transformer = transformer

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size
        """
        return self.cls_token.expand([batch, -1, -1])

    def forward(self, img, mask=None):
        batch_size = img.shape[0]
        img_patches = rearrange(
            img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.p, patch_y=self.p)
        # project patches with linear layer + add pos emb
        img_patches = self.project_patches(img_patches)

        if self.classification:
            img_patches = torch.cat(
                (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

        patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

        # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
        y = self.transformer(patch_embeddings, mask)

        if self.classification:
            # we index only the cls token for classification. nlp tricks :P
            return self.mlp_head(y[:, 0, :])
        else:
            return y

二、代码解读

2.1 大体理解

这段代码是一个实现了 Vision Transformer(ViT)模型的 PyTorch 实现。

ViT 是一个基于 Transformer 架构的图像分类模型,其主要思想是将图像分成一个个固定大小的 patch ,并将这些 patch 看做是一个个 token 输入到 Transformer 中进行特征提取和分类。

以下是对代码的解读:

  1. ViT类继承自nn.Module类,其构造函数有一系列参数,包括输入图像的尺寸、patch的大小、输出类别数、注意力机制中的头数等等。
  2. project_patches函数通过一个全连接层将每个patch映射到一个d维的特征空间中。
  3. 如果classification = True,则将一个额外的CLS token添加到输入的token序列的开头,即对于每张图像添加一个形状为[1, 1, d]的CLS token。同时,在ViT中采用的是绝对位置编码,因此还添加了一个1D的位置编码向量,其形状为[num_patches + 1, d],其中num_patches表示图像被划分成的patch数目。如果classification = False,则不添加CLS token。
  4. forward函数首先将输入的图像进行patch划分,并通过project_patches函数将每个patch映射到d维特征空间中。接着,将位置编码向量加到映射后的patch特征向量上,并进行dropout处理。如果classification=True,则在特征序列开头添加CLS token。接着将这些特征输入到Transformer中,进行特征提取。最后输出分类结果,如果classification=True,则只返回CLS token的分类结果。

2.2 详细理解

from self_attention_cv import TransformerEncoder

self_attention_cv是一个基于PyTorch实现的库,提供了在计算机视觉任务中使用自注意力机制的模块和网络,例如Transformer EncoderAttention Modules

它主要针对图像分类、对象检测、语义分割等任务,支持多种自注意力模块的实现,包括Simplified Self-AttentionFull Self-AttentionLocal Self-Attention等。此外,该库还提供了一些常见的计算机视觉任务模型的实现,例如Vision Transformer(ViT)Swin Transformer等。

TransformerEncoder是一个自注意力机制的编码器,用于将输入序列转换为编码后的序列。自注意力机制允许模型能够根据输入序列中的其他位置来加权计算每个位置的表示。这种机制在自然语言处理中的应用非常广泛,比如BERT、GPT等模型都采用了自注意力机制。

TransformerEncoder是基于PyTorch实现的,可以在计算机视觉任务中使用,例如图像分类、对象检测、语义分割等。它支持多头注意力、残差连接和LayerNorm等特性。在这个代码中,ViT模型中的Transformer部分采用了TransformerEncoder作为默认的实现。

def __init__(self, *,
                img_dim,
                in_channels=3,
                patch_dim=16,
                num_classes=10,
                dim=512,
                blocks=6,
                heads=4,
                dim_linear_block=1024,
                dim_head=None,
                dropout=0, transformer=None, classification=True):
    super().__init__()
    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification
    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

    if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

这段代码定义了一个名为 ViT 的 PyTorch 模型类,它是一个使用自注意力机制(Self-Attention)实现的视觉 Transformer 模型。其中主要参数包括:

  • img_dim:输入图片的空间大小
  • in_channels:输入图片的通道数
  • patch_dim:将图片划分成固定大小的 patch 的大小
  • num_classes:分类任务的类别数
  • dim:线性层的维度,用于将每个 patch 投影到 MHSA 空间
  • blocks:Transformer 模型中的块数
  • heads:注意力头的数量
  • dim_linear_block:线性块内部的维度
  • dim_head:每个头的维度,如果没有指定则默认为 dim/heads
  • dropout:用于位置编码和 Transformer 的 dropout 概率
  • transformer:可选的 TransformerEncoder 类实例
  • classification:是否包含额外的 CLS 标记以用于分类任务
def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
    super().__init__()

这里定义了 ViT 类的构造函数,其包含多个参数,包括输入图像大小 img_dim,输入通道数 in_channels,分块大小 patch_dim,分类数目 num_classes,嵌入维度 dim,Transformer编码器的块数 blocks,头数 heads,线性块的维度 dim_linear_block,注意力头维度 dim_head,Dropout概率 dropout,可选的Transformer编码器 transformer,以及是否进行分类的标志 classification

    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification

这里检查 img_dim 是否能够被 patch_dim 整除,如果不能整除,则会引发断言错误。同时,将 patch_dim 存储到 self.p 中,并将是否进行分类的标志存储到 self.classification 中。

    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

这里计算了输入图像中可分块的数量 tokens,并将每个块的维度 self.token_dim 设置为 in_channels * (patch_dim ** 2)

将嵌入维度 dim 存储到 self.dim 中,并根据 dim_head 是否为 None,设置注意力头维度 self.dim_headself.project_patches 是一个线性层,用于将每个块投影到嵌入空间中。

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

这里定义了嵌入层的Dropout层,并根据是否进行分类的标志,设置类别标记 self.cls_token、位置嵌入 self.pos_emb1DMLPself.mlp_head。如果不进行分类,则不需要 self.cls_tokenself.mlp_head

if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

self.emb_dropout = nn.Dropout(dropout): 定义了一个dropout层,用于在embedding后对其进行dropout操作。

if self.classification:: 如果是分类任务,就执行下面的操作,否则跳过。

self.cls_token = nn.Parameter(torch.randn(1, 1, dim)): 定义了一个可训练参数cls_token,表示分类token,它是一个1x1xdim的tensor,其中dim表示embedding维度。

self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim)): 定义了一个可训练参数pos_emb1D,表示位置嵌入,它是一个(tokens+1)xdim的tensor,其中tokens表示图像被分成的patch数,dim表示embedding维度。

self.mlp_head = nn.Linear(dim, num_classes): 定义了一个全连接层,将embedding映射到输出类别的数量。

最后,根据传入的参数来选择使用默认的TransformerEncoder,还是使用传入的transformer。如果没有传入,则使用默认的TransformerEncoder,否则使用传入的transformer。

def expand_cls_to_batch(self, batch):
    """
    Args:
        batch: batch size
    Returns: cls token expanded to the batch size
    """
    return self.cls_token.expand([batch, -1, -1])

该方法的作用是将 Transformer 中的分类 token 扩展到整个批次的样本数。它接受一个 batch 参数作为批次大小,返回一个形状为 [batch, 1, dim] 的张量,其中 dim 是 Transformer 模型的维度大小。在这个方法中,使用了 PyTorch 的 expand() 方法来实现扩展操作。

def forward(self, img, mask=None):
    batch_size = img.shape[0]
    img_patches = rearrange(
        img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                            patch_x=self.p, patch_y=self.p)
    # project patches with linear layer + add pos emb
    img_patches = self.project_patches(img_patches)

    if self.classification:
        img_patches = torch.cat(
            (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

    patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

    # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
    y = self.transformer(patch_embeddings, mask)

    if self.classification:
        # we index only the cls token for classification. nlp tricks :P
        return self.mlp_head(y[:, 0, :])
    else:
        return y

forward 函数中,接收输入的 imgmask

通过 img_dimpatch_dim 计算出 tokens 数量,其中 tokens 为图像分割成的块的数量。

将输入的 img 分成 patch,并通过 rearrange 函数重组成形状为 [batch_size, tokens, patch_dim * patch_dim * in_channels] 的张量。

通过 Linear 层将每个 patch 映射到 dim 维度,并加上位置编码 pos_emb1D

如果是用于分类任务,则在序列的开头插入一个 CLS token,然后与处理后的 patch 张量按列拼接。

对 patch_embeddings 应用 dropout,并输入到 TransformerEncoder 中,返回输出张量 y,形状为 [batch_size, tokens, dim]

如果是用于分类任务,则从 y 中取出 CLS token,输入到一个 Linear 层中进行分类,输出分类结果。

如果不是分类任务,则直接返回 y。

  • 4
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
Transformer在许多NLP(自然语言处理)任务中取得了最先进的成果。 ViT (Vision Transformer)是Transformer应用于CV(计算机视觉)领域里程碑式的工作,后面发展出更多的变体,如Swin Transformer。 ViT (Vision Transformer)模型发表于论文An Image is Worth 16X16 Words: Transformer For Image Recognition At Scale,使用纯Transformer进图像分类。ViT在JFT-300M数据集上预训练后,可超过卷积神经网络ResNet的性能,并且所用的训练计算资源可更少。 本课程对ViT的原理与PyTorch实现代码精讲,来帮助大家掌握其详细原理和具体实现。其中代码实现包含两种代码实现方式,一种是采用timm库,另一种是采用einops/einsum。  原理精讲部分包括:Transformer的架构概述、Transformer的Encoder 、Transformer的Decoder、ViT架构概述、ViT模型详解、ViT性能及分析。  代码精讲部分使用Jupyter Notebook对ViT的PyTorch代码解读,包括:安装PyTorch、ViT的timm库实现代码解读、 einops/einsum 、ViT的einops/einsum实现代码解读。相关课程: 《Transformer原理与代码精讲(PyTorch)》https://edu.csdn.net/course/detail/36697《Transformer原理与代码精讲(TensorFlow)》https://edu.csdn.net/course/detail/36699《ViT(Vision Transformer)原理与代码精讲》https://edu.csdn.net/course/detail/36719《DETR原理与代码精讲》https://edu.csdn.net/course/detail/36768《Swin Transformer实战目标检测:训练自己的数据集》https://edu.csdn.net/course/detail/36585《Swin Transformer实战实例分割:训练自己的数据集》https://edu.csdn.net/course/detail/36586 《Swin Transformer原理与代码精讲》 https://download.csdn.net/course/detail/37045 
以下是一个简单的示例代码,用于实现Vision Transformer (ViT)的Transformer模型部分: ```python import torch import torch.nn as nn import torch.nn.functional as F class PatchEmbedding(nn.Module): def __init__(self, image_size, patch_size, in_channels, embed_dim): super().__init__() self.image_size = image_size self.patch_size = patch_size self.num_patches = (image_size // patch_size) ** 2 self.patch_embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.patch_embedding(x) x = x.flatten(2).transpose(1, 2) return x class Transformer(nn.Module): def __init__(self, embed_dim, num_heads, num_layers, hidden_dim, dropout): super().__init__() self.encoder_layers = nn.ModuleList([ nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout) for _ in range(num_layers) ]) def forward(self, x): for layer in self.encoder_layers: x = layer(x) return x class ViT(nn.Module): def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, num_layers, hidden_dim, num_classes): super().__init__() self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim) self.transformer = Transformer(embed_dim, num_heads, num_layers, hidden_dim, dropout=0.1) self.classifier = nn.Linear(embed_dim, num_classes) def forward(self, x): x = self.patch_embedding(x) x = self.transformer(x) x = x.mean(1) x = self.classifier(x) return x ``` 这段代码定义了一个简单的Vision Transformer模型,包括PatchEmbedding模块、Transformer模块和ViT模型。你可以根据需要进修改和扩展。请注意,此代码只包括Transformer的模型部分,有关数据加载和训练的部分需要根据具体任务进实现。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

旅途中的宽~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值