Vision Transformer结构与代码详解

Vision Transformer整体结构如下
在这里插入图片描述
在图像分类中,Vision Transformer (ViT)主要包括图像分块与嵌入、位置编码、编码器、分类头几个阶段。在图像分类任务中,重点在于提取图像的整体特征以判断类别,编码器足以完成特征提取的任务,所以无需解码器。

以下详细解释每个阶段的作用原理

图像分块与嵌入

在这里插入图片描述
图像分块:将图像分割成多个固定大小的小块。例如,对于一张 224×224 像素的图像,分割成 14×14 个大小为 16×16 像素的小块。这样做是为了将图像转化为适合 Transformer 处理的序列形式,每个小块可以看作是序列中的一个元素,便于模型对图像局部信息进行处理和分析。
嵌入操作:将每个图像小块展平为一维向量,然后通过线性投影层映射到固定维度的向量空间,得到嵌入向量。这相当于为每个小块赋予一个特征向量表示,使得模型能够基于这些向量进行后续的计算和学习,将图像的像素信息转化为模型可理解的特征空间中的向量。

# 图像分块与嵌入
class PatchEmbedding(nn.Module):
    """
    PatchEmbedding 类通过卷积操作实现了图像的分块与嵌入,为后续的 Transformer 编码器提供了合适的输入格式.
    参数说明:
        image_size: 输入图像的尺寸。
        patch_size: 图像分块的大小。
        in_channels: 输入图像的通道数。
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
    """
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2  # 图像被分割为 (224 // 16) ** 2 = 196 个块
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)  # 每个块被映射到 768 维度的空间

    def forward(self, x):  # (1, 3, 224, 224)
        x = self.proj(x)  # (1, 768, 14, 14)
        x = x.flatten(2).transpose(1, 2)  # (1, 196, 768)
        return x

位置编码

在这里插入图片描述
由于 Transformer 的自注意力机制本身无法捕捉位置信息,而图像中各小块的位置对于图像分类很重要,所以需要添加位置编码。位置编码是根据小块在图像中的位置确定的固定维度向量,将其与嵌入向量相加,能让模型学习到图像中各小块的相对位置关系,弥补自注意力机制在位置信息上的缺失,使模型能够利用位置信息来更好地理解图像内容。

# 位置编码
class PositionalEncoding(nn.Module):
    """
    PositionalEncoding 类通过向输入张量添加一个可学习的位置编码参数,使得模型能够感知到序列中各个元素的位置关系.
    位置编码的作用: 在 Transformer 模型中,自注意力机制(Self-Attention)本身并不包含对序列顺序的感知能力.
                    因此,需要通过位置编码将序列中元素的位置信息显式地加入模型.
    可学习的位置编码: 与固定的位置编码(如正弦/余弦函数生成的位置编码)不同,这里的 pos_embedding 是一个可学习的参数,会在训练过程中优化.
    参数说明:
        num_patches: 图像分块的数量。
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
    """
    def __init__(self, num_patches, embed_dim):
        super(PositionalEncoding, self).__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))  # 在 pos_embedding 的定义中,额外增加了一个位置(num_patches + 1),用于容纳 [CLS] token 的位置编码。

    def forward(self, x):  # (1, 197, 768)
        x = x + self.pos_embedding  # 将 self.pos_embedding 添加到输入张量 x 上,以注入位置信息 (1, 197, 768) 
        return x

注意:[CLS] token 是 Transformer 模型中用于分类任务的一个特殊标记,它的作用是:
1 提供一个明确的位置,用于表示整个输入序列的全局特征。
2 通过 Transformer 编码器的学习,捕获输入数据的综合信息。
3 作为分类任务的输入特征,传递给全连接层以生成最终的预测结果。
4 在图像分类任务中,[CLS] token 的引入使得模型能够更高效地利用全局信息进行分类决策。

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

Transformer编码器

由多个相同的单个编码器层组成而成。
在这里插入图片描述

多头自注意力机制

允许每个图像小块的嵌入向量与其他所有小块的嵌入向量进行交互。通过计算查询向量(Query)、键向量(Key)和值向量(Value)之间的相似度(通常用点积计算),并根据相似度对其他小块的信息进行加权求和,得到该小块的更新表示。这样可以捕捉图像中的长距离依赖关系,让模型能够综合考虑图像中不同位置的信息,例如可以捕捉到图像中相隔较远的物体之间的关系。
多头自注意力机制的核心思想是通过计算输入序列中不同位置之间的关系,捕获全局依赖性。具体来说:
1 多头设计:
每个头独立计算注意力,能够从不同的子空间中提取信息。最终将多个头的输出拼接在一起,增加模型的表达能力。
2 缩放因子:
除以 sqrt(head_dim) 是为了稳定梯度,避免点积结果过大导致 softmax 数值不稳定。
3 应用场景:
在 Transformer 编码器中,多头自注意力机制用于捕捉输入序列中的长距离依赖关系。在图像分类任务中,它能够帮助模型理解图像 patches 之间的相互关系。

# 多头自注意力机制
class MultiHeadSelfAttention(nn.Module):
    """
    多头自注意力机制的核心思想是通过计算输入序列中不同位置之间的关系,捕获全局依赖性.
    经过 MultiHeadSelfAttention 的前向传播后,输出张量的形状仍为 (1, 197, 768),但每个位置的特征向量已经包含了与其他位置的关系信息。
    参数说明:
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
        num_heads: 多头自注意力机制的头数。
    """
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"  # 确保嵌入维度可以被头数整除,否则无法均匀分配每个头的维度
        self.embed_dim = embed_dim  # 嵌入维度
        self.num_heads = num_heads  # 多头自注意力的头数
        self.head_dim = embed_dim // num_heads  # 每个头的嵌入维度
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)  # 定义一个线性层,用于生成查询(Q)、键(K)和值(V)
        self.out_proj = nn.Linear(embed_dim, embed_dim)  # 定义一个线性层,用于将多头注意力的输出投影回原始嵌入维度

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.size()  # 输入张量的形状,表示批量大小、序列长度和嵌入维度(1, 196+1, 768)
        qkv = self.qkv_proj(x)  # 使用线性层生成 Q、K、V,形状为 (batch_size, seq_length, 3 * embed_dim)=(1, 197, 3 * 768)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim).permute(0, 2, 1, 3)  # 将 Q、K、V 分割为多个头,并调整形状为 (batch_size, num_heads, seq_length, head_dim)=(1, 12, 197, 3 * 64)
        q, k, v = qkv.chunk(3, dim=-1)  # 将 Q、K、V 分离,每个形状为 (batch_size, num_heads, seq_length, head_dim)=(1, 12, 197, 64)

        # 输入形状:q.shape = (1, 12, 197, 64) 和 k.T.shape = (1, 12, 64, 197)  输出形状:attn_scores.shape = (1, 12, 197, 197)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # 计算 Q 和 K 的点积,得到注意力分数矩阵,以及除以 sqrt(head_dim) 进行缩放,避免数值过大导致 softmax 梯度消失
        attn_probs = F.softmax(attn_scores, dim=-1)  # 对注意力分数应用 softmax,得到归一化的注意力概率=(1, 12, 197, 197)

        # 输入形状:attn_probs.shape = (1, 12, 197, 197) 和 v.shape = (1, 12, 197, 64)  输出形状:attn_output.shape = (1, 12, 197, 64)
        attn_output = torch.matmul(attn_probs, v)  # 据注意力概率对 V 进行加权求和,得到注意力输出
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_length, embed_dim)  # (1, 197, 768)
        output = self.out_proj(attn_output)  # 使用线性层将多头注意力的输出投影回原始嵌入维度  # (1, 197, 768)
        return output 
前馈神经网络

自注意力机制的输出经过前馈神经网络进行进一步的特征变换。该网络通常由两个线性层和一个非线性激活函数(如 ReLU)组成,用于对自注意力机制提取的特征进行进一步的加工和提炼,提取更高级的特征表示,增强模型的表达能力。

# 前馈神经网络
class FeedForward(nn.Module):
    """
    前馈神经网络的作用是对多头自注意力机制的输出进行进一步处理,增强模型的表达能力.
    参数说明:
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
        hidden_dim: 前馈神经网络的隐藏层维度。
    """
    def __init__(self, embed_dim, hidden_dim):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)  # 将输入张量从嵌入维度映射到隐藏维度=(1, 197, 3072)
        x = self.relu(x)  # 应用 ReLU 激活函数,引入非线性
        x = self.fc2(x)  # 将隐藏维度的输出映射回嵌入维度=(1, 197, 768)
        return x

单个编码器层

每个编码器层由多头自注意力机制MultiHeadSelfAttention和前馈神经网络FeedForward组成,并通过残差连接和层归一化LayerNorm增强模型的性能。

# 编码器层
class EncoderLayer(nn.Module):
    """
    EncoderLayer 类实现了 Transformer 编码器中的单个编码器层。每个编码器层由多头自注意力机制(MultiHeadSelfAttention)和前馈神经网络(FeedForward)组成,
    并通过残差连接和层归一化(LayerNorm)增强模型的性能。
    参数说明:
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
        num_heads: 多头自注意力机制的头数。
        hidden_dim: 前馈神经网络的隐藏层维度。
    """
    def __init__(self, embed_dim, num_heads, hidden_dim):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.feed_forward = FeedForward(embed_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_output = self.self_attn(x)  # 多头自注意力机制 (1, 197, 768)
        x = x + attn_output  # 残差连接逐元素相加 (1, 197, 768)
        x = self.norm1(x)  # 层归一化 (1, 197, 768)
        ff_output = self.feed_forward(x)  # 前馈神经网络 (1, 197, 768)
        x = x + ff_output  # 残差连接逐元素相加 (1, 197, 768)
        x = self.norm2(x)  # 层归一化 (1, 197, 768)
        return x

多层堆叠

由多个相同的单编码器层堆叠而成,每一层都对输入特征进行进一步的抽象和提炼。随着层数的增加,模型能够学习到更加复杂和抽象的图像特征,逐步从简单的局部特征过渡到更具语义的全局特征,提高模型对图像的理解和分类能力。

# 完整的Transformer编码器
class TransformerEncoder(nn.Module):
    """
    TransformerEncoder 类实现了完整的 Transformer 编码器,由多个编码器层(EncoderLayer)堆叠而成。
    参数说明:
        num_layers: 编码器层数,表示堆叠的 EncoderLayer 数量。
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
        num_heads: 多头自注意力机制的头数。
        hidden_dim: 前馈神经网络的隐藏层维度
    """
    def __init__(self, num_layers, embed_dim, num_heads, hidden_dim):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(embed_dim, num_heads, hidden_dim) for _ in range(num_layers)])

    def forward(self, x):  # (1, 197, 768)
        for layer in self.layers:
            x = layer(x)  # 将输入张量依次通过每个 EncoderLayer,进行多头自注意力机制和前馈神经网络的处理
        return x  # (1, 197, 768)

分类头

在这里插入图片描述
添加分类标记:在编码器的最后,添加一个特殊的分类标记([CLS])。这个标记在输入时与其他图像小块的嵌入向量一起进入编码器,经过多层处理后,它融合了整个图像的信息,相当于对整幅图像的一个综合表示。
全连接层与分类:编码器输出的分类标记向量被输入到全连接层(分类头)中,全连接层的输出维度等于图像分类的类别数。通过对全连接层的输出进行 softmax 操作,得到每个类别的概率分布,从而确定图像所属的类别,实现图像分类的功能。

self.fc = nn.Linear(embed_dim, num_classes)  # 创建一个线性层,用于将编码器的输出映射到类别数

注意:在 TransformerImageClassifier 类中,分类头的全连接层没有接入 softmax 操作,主要是为了与 CrossEntropyLoss 等标准损失函数兼容,提高数值稳定性,并保持模型输出的灵活性。这种设计使得代码更加简洁,并且能够更好地适应不同的应用场景。

VIT模型总体代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

"""
代码说明:
1.PatchEmbedding类:负责将输入图像分块并嵌入到指定维度的向量空间。
2.PositionalEncoding类:为输入序列添加位置编码。
3.MultiHeadSelfAttention类:实现多头自注意力机制。
4.FeedForward类:前馈神经网络,用于对自注意力机制的输出进行进一步处理。
5.EncoderLayer类:单个编码器层,包含自注意力机制和前馈神经网络。
6.TransformerEncoder类:由多个编码器层堆叠而成的完整编码器。
7.TransformerImageClassifier类:完整的基于 Transformer 的图像分类模型,包含图像分块与嵌入、位置编码、编码器和分类头。
"""

# 图像分块与嵌入
class PatchEmbedding(nn.Module):
    """
    PatchEmbedding 类通过卷积操作实现了图像的分块与嵌入,为后续的 Transformer 编码器提供了合适的输入格式.
    参数说明:
        image_size: 输入图像的尺寸。
        patch_size: 图像分块的大小。
        in_channels: 输入图像的通道数。
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
    """
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2  # 图像被分割为 (224 // 16) ** 2 = 196 个块
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)  # 每个块被映射到 768 维度的空间

    def forward(self, x):  # (1, 3, 224, 224)
        x = self.proj(x)  # (1, 768, 14, 14)
        x = x.flatten(2).transpose(1, 2)  # (1, 196, 768)
        return x


# 位置编码
class PositionalEncoding(nn.Module):
    """
    PositionalEncoding 类通过向输入张量添加一个可学习的位置编码参数,使得模型能够感知到序列中各个元素的位置关系.
    位置编码的作用: 在 Transformer 模型中,自注意力机制(Self-Attention)本身并不包含对序列顺序的感知能力.
                    因此,需要通过位置编码将序列中元素的位置信息显式地加入模型.
    可学习的位置编码: 与固定的位置编码(如正弦/余弦函数生成的位置编码)不同,这里的 pos_embedding 是一个可学习的参数,会在训练过程中优化.
    参数说明:
        num_patches: 图像分块的数量。
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
    """
    def __init__(self, num_patches, embed_dim):
        super(PositionalEncoding, self).__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))  # 在 pos_embedding 的定义中,额外增加了一个位置(num_patches + 1),用于容纳 [CLS] token 的位置编码。

    def forward(self, x):  # (1, 197, 768)
        x = x + self.pos_embedding  # 将 self.pos_embedding 添加到输入张量 x 上,以注入位置信息 (1, 197, 768) 
        return x


# 多头自注意力机制
class MultiHeadSelfAttention(nn.Module):
    """
    多头自注意力机制的核心思想是通过计算输入序列中不同位置之间的关系,捕获全局依赖性.
    经过 MultiHeadSelfAttention 的前向传播后,输出张量的形状仍为 (1, 197, 768),但每个位置的特征向量已经包含了与其他位置的关系信息。
    参数说明:
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
        num_heads: 多头自注意力机制的头数。
    """
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"  # 确保嵌入维度可以被头数整除,否则无法均匀分配每个头的维度
        self.embed_dim = embed_dim  # 嵌入维度
        self.num_heads = num_heads  # 多头自注意力的头数
        self.head_dim = embed_dim // num_heads  # 每个头的嵌入维度
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)  # 定义一个线性层,用于生成查询(Q)、键(K)和值(V)
        self.out_proj = nn.Linear(embed_dim, embed_dim)  # 定义一个线性层,用于将多头注意力的输出投影回原始嵌入维度

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.size()  # 输入张量的形状,表示批量大小、序列长度和嵌入维度(1, 196+1, 768)
        qkv = self.qkv_proj(x)  # 使用线性层生成 Q、K、V,形状为 (batch_size, seq_length, 3 * embed_dim)=(1, 197, 3 * 768)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim).permute(0, 2, 1, 3)  # 将 Q、K、V 分割为多个头,并调整形状为 (batch_size, num_heads, seq_length, head_dim)=(1, 12, 197, 3 * 64)
        q, k, v = qkv.chunk(3, dim=-1)  # 将 Q、K、V 分离,每个形状为 (batch_size, num_heads, seq_length, head_dim)=(1, 12, 197, 64)

        # 输入形状:q.shape = (1, 12, 197, 64) 和 k.T.shape = (1, 12, 64, 197)  输出形状:attn_scores.shape = (1, 12, 197, 197)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # 计算 Q 和 K 的点积,得到注意力分数矩阵,以及除以 sqrt(head_dim) 进行缩放,避免数值过大导致 softmax 梯度消失
        attn_probs = F.softmax(attn_scores, dim=-1)  # 对注意力分数应用 softmax,得到归一化的注意力概率=(1, 12, 197, 197)

        # 输入形状:attn_probs.shape = (1, 12, 197, 197) 和 v.shape = (1, 12, 197, 64)  输出形状:attn_output.shape = (1, 12, 197, 64)
        attn_output = torch.matmul(attn_probs, v)  # 据注意力概率对 V 进行加权求和,得到注意力输出
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_length, embed_dim)  # (1, 197, 768)
        output = self.out_proj(attn_output)  # 使用线性层将多头注意力的输出投影回原始嵌入维度  # (1, 197, 768)
        return output  


# 前馈神经网络
class FeedForward(nn.Module):
    """
    前馈神经网络的作用是对多头自注意力机制的输出进行进一步处理,增强模型的表达能力.
    参数说明:
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
        hidden_dim: 前馈神经网络的隐藏层维度。
    """
    def __init__(self, embed_dim, hidden_dim):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)  # 将输入张量从嵌入维度映射到隐藏维度=(1, 197, 3072)
        x = self.relu(x)  # 应用 ReLU 激活函数,引入非线性
        x = self.fc2(x)  # 将隐藏维度的输出映射回嵌入维度=(1, 197, 768)
        return x


# 编码器层
class EncoderLayer(nn.Module):
    """
    EncoderLayer 类实现了 Transformer 编码器中的单个编码器层。每个编码器层由多头自注意力机制(MultiHeadSelfAttention)和前馈神经网络(FeedForward)组成,
    并通过残差连接和层归一化(LayerNorm)增强模型的性能。
    参数说明:
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
        num_heads: 多头自注意力机制的头数。
        hidden_dim: 前馈神经网络的隐藏层维度。
    """
    def __init__(self, embed_dim, num_heads, hidden_dim):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.feed_forward = FeedForward(embed_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_output = self.self_attn(x)  # 多头自注意力机制 (1, 197, 768)
        x = x + attn_output  # 残差连接逐元素相加 (1, 197, 768)
        x = self.norm1(x)  # 层归一化 (1, 197, 768)
        ff_output = self.feed_forward(x)  # 前馈神经网络 (1, 197, 768)
        x = x + ff_output  # 残差连接逐元素相加 (1, 197, 768)
        x = self.norm2(x)  # 层归一化 (1, 197, 768)
        return x


# 完整的Transformer编码器
class TransformerEncoder(nn.Module):
    """
    TransformerEncoder 类实现了完整的 Transformer 编码器,由多个编码器层(EncoderLayer)堆叠而成。
    参数说明:
        num_layers: 编码器层数,表示堆叠的 EncoderLayer 数量。
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
        num_heads: 多头自注意力机制的头数。
        hidden_dim: 前馈神经网络的隐藏层维度
    """
    def __init__(self, num_layers, embed_dim, num_heads, hidden_dim):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(embed_dim, num_heads, hidden_dim) for _ in range(num_layers)])

    def forward(self, x):  # (1, 197, 768)
        for layer in self.layers:
            x = layer(x)  # 将输入张量依次通过每个 EncoderLayer,进行多头自注意力机制和前馈神经网络的处理
        return x  # (1, 197, 768)


# 基于Transformer的图像分类模型
class TransformerImageClassifier(nn.Module):
    """
    TransformerImageClassifier 类是一个完整的基于 Transformer 的图像分类模型,包含了图像分块与嵌入、位置编码、编码器和分类头。
    参数说明:
        image_size: 输入图像的尺寸。
        patch_size: 图像分块的大小。
        in_channels: 输入图像的通道数。
        embed_dim: 嵌入维度,表示输入和输出的特征向量长度。
        num_heads: 多头自注意力机制的头数。
        hidden_dim: 前馈神经网络的隐藏层维度。
        num_layers: 编码器层数,表示堆叠的 EncoderLayer 数量。
        num_classes: 分类头的输出类别数。
    """
    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, hidden_dim, num_layers, num_classes):
        super(TransformerImageClassifier, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)  # 创建 PatchEmbedding 实例,负责将输入图像分块并嵌入到指定维度的向量空间
        self.num_patches = self.patch_embedding.num_patches  # 计算图像分块的数量
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # 定义一个可学习的 [CLS] token,用于分类任务
        self.positional_encoding = PositionalEncoding(self.num_patches, embed_dim)  # 创建 PositionalEncoding 实例,为输入序列添加位置编码
        self.encoder = TransformerEncoder(num_layers, embed_dim, num_heads, hidden_dim)  # 创建 TransformerEncoder 实例,包含多个编码器层
        self.fc = nn.Linear(embed_dim, num_classes)  # 创建一个线性层,用于将编码器的输出映射到类别数

    def forward(self, x):  # (1, 3, 224, 224)
        x = self.patch_embedding(x)  # (1, 196, 768)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (1, 1, 768)
        x = torch.cat((cls_tokens, x), dim=1)  # (1, 197, 768)
        x = self.positional_encoding(x)  # (1, 197, 768)
        x = self.encoder(x)  # (1, 197, 768)
        cls_output = x[:, 0]  # (1, 768)
        output = self.fc(cls_output)  # (1, 10)
        return output


# 示例使用
if __name__ == "__main__":
    # 模型参数设置
    image_size = 224   # 输入图像的尺寸
    patch_size = 16    # 图像分块的大小
    in_channels = 3    # 输入图像的通道数
    embed_dim = 768    # 嵌入维度,表示输入和输出的特征向量长度
    num_heads = 12     # 多头自注意力机制的头数
    hidden_dim = 3072  # 隐藏层维度,通常比 embed_dim 大得多,用于增加模型的表达能力
    num_layers = 12    # 编码器层数,表示堆叠的 EncoderLayer 数量
    num_classes = 10   # 分类头的输出类别数

    # 创建模型实例
    model = TransformerImageClassifier(image_size, patch_size, in_channels, embed_dim, num_heads, hidden_dim, num_layers,
                                       num_classes)

    # 生成随机输入
    input_tensor = torch.randn(1, 3, 224, 224)

    # 前向传播
    output = model(input_tensor)
    print(output.shape)  # torch.Size([1, 10])

参考链接
BatchNorm和LayerNorm——通俗易懂的理解
详解 Vision Transformer (ViT)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值