地表最用心!Vision Transformer(ViT)数据流解析(代码同步)

文章详细介绍了ViT模型中的关键组件,如图像块嵌入到Transformer编码器的转换过程,位置和类别标记的引入,Dropout用于正则化,以及Transformer编码器如何通过多头自注意力和前馈网络处理序列信息。最后,文章涵盖了池化策略(CLSTokenPooling和MeanPooling)和分类头用于生成预测概率的流程。
摘要由CSDN通过智能技术生成

Vision Transformer
图像块嵌入 → 位置嵌入和类别标记 → Dropout → Transformer编码器 → 池化和分类

代码来自:https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

1、图像块嵌入 to_patch-embedding

在这里插入图片描述
( 1 , 3 , 512 , 512 )图像 → ( 1 , 1024 , 256 )序列元素 \color{red}{(1,3,512,512)图像→(1,1024,256)序列元素} 13512512)图像11024256)序列元素

“嵌入”:是指将数据从原始形式(像素)转换成模型能够处理的高维空间的表示形式。嵌入能够捕捉到输入的数据特征,并将其投影到一个更适合模型学习和提取信息的空间中。

"""
Rearrange: 用于重新排列和变换输入张量的维度。
'b c (h p1) (w p2)'表示输入张量的维度,其中:b是批次大小。c是通道数,对于RGB图像,c是3。
(h p1)和(w p2)表示图像高度和宽度被分割成多个大小为p1和p2的块(patches)。如果p1=p2=16,这意味着每个16x16像素的区域会被处理为一个块。
-> b (h w) (p1 p2 c)指定了输出张量的维度排列方式:
b保持不变,仍然表示批次大小。
(h w)表示将原图像的高度和宽度维度合并,并且按照块(patch)的方式重新组织,每个块对应原图中的一个区域。这里假设高度和宽度完全可以被p1和p2整除,这样就可以得到整数个块。
(p1 p2 c)表示每个块被扁平化成一个向量,包括了所有像素和通道的信息。
对于每个16x16的块,如果c=3(RGB图像),则每个块扁平化后的大小是16*16*3=768。
"""
self.to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
    nn.LayerNorm(patch_dim),
    nn.Linear(patch_dim, dim),
    nn.LayerNorm(dim),
)

(1)Rearrange层 → (1,1024,768)
假设 patch_height = patch_width = 16,意味着原图被分成32x32个16x16大小的图像块。
经过重排,输出的维度是(1,32x32,16x16x3(图像扁平化,将二维图像中的像素值展平成一个一维向量))=(1,1024,768)。

(2)第一个LayerNorm层 → (1,1024,768)
归一化:patch_dim = Patch_height * patch_width * channels

(3)Linear层 → (1,1024,256)
全连接层,将维度映射到目标维度dim设置为256。

(4)第二个LayerNorm层 → (1,1024,256)
同上,归一化,不改变输入维度。
(2)(3)(4)为结构图中所示“Linear Projection of Flattened Patches”
线性层的任务是将每个扁平化的补丁向量转换到模型预定的维度空间dim中。这是为了使每个补丁的表示与Transformer模型中的维度一致,并为后续的多头自注意力层做准备。

2、位置嵌入和类别标记pos_embedding、cls_token、dropout

在这里插入图片描述

( 1 , 1024 , 256 )序列张量 → ( 1 , 1025 , 256 )序列张量,序列中包括 1024 个图像补丁 + 1 个类别标记 \color{red}{(1,1024,256)序列张量→(1,1025,256)序列张量,序列中包括1024个图像补丁+1个类别标记} 11024256)序列张量11025256)序列张量,序列中包括1024个图像补丁+1个类别标记
Transformer允许模型学习输入数据的顺序。由于Transformer本身并不具备处理顺序信息的能力,因此需要通过位置嵌入来提供每个输入元素在序列中的位置信息。

(1)每个图像块(图中的0到9)都会通过一个线性层(在代码中由nn.Linear实现)转换为一个高维空间的表示。
同时,有一个额外的“CLS”嵌入(cls_token),它在代码中初始化并准备与图像块的嵌入合并。

(2)为每个图像块和“CLS”嵌入添加位置信息,这是通过将每个块的嵌入与对应的位置嵌入相加完成的(在代码中通过x += self.pos_embedding[:, :(n + 1)]实现)。

(3)最终的结果是一个序列,其中包含了位置信息的图像块嵌入,以及一个用于分类的“CLS”嵌入。

从图中可以看到:
图像块(Patches):输入图像被切割成多个图像块,这里以数字0-9进行标记,每个数字代表一个图像块的线性嵌入
类别嵌入([class]Embedding):在序列的最前边,有一个特殊的嵌入,通常被称为“类别”嵌入或者“CLS”嵌入(图中的0*)。这个嵌入是可学习的,旨在整个训练过程中捕获全局图像表示,最终用于分类任务。
位置嵌入(Position Embedding):每个图像块嵌入都会与一个位置嵌入向量相加。这些位置嵌入向量是可学习的,意味着它们会在训练过程中调整,以最佳方式表示块在原始图像中的位置信息。

 self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

self.pos_embedding

  • 这是一个可训练的参数,其形状为 (1, num_patches + 1, dim)。
  • num_patches 是图像分割后得到的补丁数。在前面的步骤中,对于 512x512 图像和 16x16 的补丁大小,我们得到了 1024 个补丁。
  • +1 是为了考虑额外的类别标记(cls_token),它会被添加到序列的开头,用于分类任务的最终预测。
  • dim 是模型的维度,与前面线性层的输出维度相匹配。
  • nn.Parameter 表示这是一个可训练的参数,模型在训练过程中会调整这些值。
 self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

self.cls_token

  • 也是一个可训练的参数,形状为 (1, 1, dim)。
  • 它表示类别标记,它的目的是代表整个图像的全局信息。
  • 这个 cls_token 会被复制并添加到所有补丁的前面。
def forward:
	cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
	x = torch.cat((cls_tokens, x), dim=1)
	x += self.pos_embedding[:, :(n + 1)]
	x = self.dropout(x)

主要步骤:
(1)扩展 cls_token:
使用 repeat()cls_token 被复制以匹配批次大小 b,所以它可以与每个图像的补丁序列一起使用。

(2)连接 cls_token:
使用 torch.cat,将扩展后的 cls_token 添加到补丁嵌入的序列开头。

(3)添加位置嵌入:
通过广播机制将位置嵌入添加到每个补丁嵌入上。因为 pos_embedding 的第一个维度是1,所以它可以自动扩展到整个批次。

(4)应用 Dropout:
self.dropout 可能会应用于提高训练过程中的泛化能力,但它不改变数据的形状。

3、Transformer Decode

( 1 , 1025 , 256 )序列张量 → ( 1 , 1025 , 256 )序列张量 ,每个元素的特征融合了整个图像的信息 \color{red}{(1,1025,256)序列张量→(1,1025,256)序列张量},每个元素的特征融合了整个图像的信息 11025256)序列张量11025256)序列张量,每个元素的特征融合了整个图像的信息
decode

(1)多头自注意力 (Multi-Head Attention):

  • Attention 类实现了多头自注意力机制,其中包含了将输入分割成多个头、进行自注意力运算,并最终合并结果的过程。
  • Attention模块中,输入x首先通过层归一化,然后是线性层生成查询(query)、键(key)和值(value)。
  • 接下来是自注意力运算,其中包括了缩放点积注意力和softmax操作。
  • 最后,将注意力的输出通过另一个线性层进行投影,如果需要的话,还有一个dropout操作。

(2)残差连接:
在自注意力模块之后,将自注意力的输出与输入x相加,形成一个残差连接。这有助于避免在深层网络中出现梯度消失的问题。

(3) 前馈网络 (Feed Forward Network):

  • FeedForward 类实现了前馈网络。它是由两个线性层和一个激活函数(GELU)组成的序列。
  • 和自注意力模块一样,前馈网络的输出也通过残差连接与输入相加。

(4)叠加多个编码器层:

  • depth 参数指定了要堆叠多少个这样的编码器层。这样,输入x就会按顺序通过这些层进行处理。

在每个编码器层中,数据的形状不会改变。如果输入的形状是 (1, 1025, 256),经过每一层之后,它仍然保持不变。残差连接和层归一化帮助模型在这些层之间平滑地传递信息,并有利于训练深层网络。

在整个Transformer编码器处理结束后,我们得到的输出张量仍然是 (1, 1025, 256) 的形状,但每个补丁的表示现在都富含了与其他补丁的关系信息。这个输出随后将用于分类任务,通常是通过取第一个元素(也就是与cls_token对应的表示)来进行最终的预测。

4、池化(CLS

( 1 , 1024 , 256 )序列张量 → ( 1 , 256 )张量 \color{red}{(1,1024,256)序列张量→(1,256)张量} 11024256)序列张量1256)张量
ViT中可选择两种池化策略:
(1)CLS Token Pooling: 使用特别的类别标记(cls_token)作为整个图像的表示。这是Transformer在NLP任务中常用的策略,比如BERT模型用于句子分类任务。在图像分类的上下文中,cls_token是编码器的第一个输出,并被用作图像的全局表示。
(2)Mean Pooling: 计算编码器输出的序列维度上的平均值,得到一个整体的图像表示。这意味着将所有补丁的表示平均化,来作为图像的表示。
在代码中forward中,这两种池化策略通过一个条件语句进行选择:

x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

池化(Pooling)是在神经网络中常用的降维操作,目的是减小数据的空间维度(即高度和宽度),从而减少计算量和避免过拟合,同时也尝试保留重要的信息。在Vision Transformer(ViT)中,池化操作用于聚合Transformer编码器的输出,为分类头(即最后的线性层)提供一个固定大小的表示。

在ViT模型中,有两种常见的池化策略:

  1. CLS Token Pooling: 使用特别的类别标记(cls_token)作为整个图像的表示。这是Transformer在NLP任务中常用的策略,比如BERT模型用于句子分类任务。在图像分类的上下文中,cls_token是编码器的第一个输出,并被用作图像的全局表示。

  2. Mean Pooling: 计算编码器输出的序列维度上的平均值,得到一个整体的图像表示。这意味着将所有补丁的表示平均化,来作为图像的表示。

在代码中,这两种池化策略通过一个条件语句选择:

x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
  • 如果self.pool属性被设置为'mean',则代码执行x.mean(dim = 1)。这会计算x张量在维度1上的均值(即所有补丁的均值),结果是每个样本只有一个向量表示。如果输入x的形状是(1, 1025, 256),那么执行平均池化后,输出的形状将是(1, 256)

  • 如果self.pool不是'mean',则假设是使用CLS Token,代码执行x[:, 0]。这会选取x中每个样本的第一个元素(即CLS Token的输出),也是为了每个样本得到一个向量表示。对于输入形状为(1, 1025, 256)x,选取CLS Token后,输出的形状同样为(1, 256)
    无论哪种池化策略,得到的输出都是一个固定长度的向量,可以被直接用于分类任务。这个向量随后会被传递到模型的最后一个线性层,进行最终的类别预测。

5、分类头(Classifier Head

在这里插入图片描述

( 1 , 256 )张量 → ( 1 , n u m c l a s s e s )张量,类型为概率分布,表示不同类别的预测概率 \color{red}{(1,256)张量→(1,num classes)张量,类型为概率分布,表示不同类别的预测概率} 1256)张量1numclasses)张量,类型为概率分布,表示不同类别的预测概率

self.mlp_head = nn.Linear(dim, num_classes)

self.mlp_head是一个线性层,dim是Transformer编码器的输出维度(也就是每个补丁的嵌入维度),而num_classes是模型需要预测的类别数。这个线性层的作用是将编码器的输出(或者说是池化后的表示)从dim维映射到num_classes维,每一维对应一个特定的类别的得分。

def forward
	x = self.to_latent(x)
	x = self.mlp_head(x)
  • self.to_latent(x)通常是一个恒等操作(nn.Identity()),意味着它不会改变输入。其存在是为了提供一个扩展点,以便在必要时可以加入额外的处理步骤。
  • self.mlp_head(x)应用了上述的线性映射。

如果我们从池化步骤得到的输出形状是(1, 256)(假设dim=256),那么应用self.mlp_head之后,输出形状将变为(1, num_classes)num_classes是模型被训练识别的类别数量。例如,如果是一个1000类分类问题,输出形状将是(1, 1000)
最终输出的每一个维度代表模型对于相应类别的预测置信度。在实际应用中,这些置信度通常通过softmax函数转换为概率,以便更直观地表示模型对各个类别的预测概率。然而,softmax操作通常包含在损失函数中(例如,使用nn.CrossEntropyLoss时),而不是显式地作为模型的一部分。

  • 14
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这段代码是用于实现Vision Transformer框架的一部分功能,具体逐行解析如下: 1. `conv_output = F.conv2d(image, kernel, stride=stride)`: 这一行代码使用PyTorch中的卷积函数`F.conv2d`来对输入图像进行卷积操作。 2. `bs, oc, oh, ow = conv_output.shape`: 这一行代码通过`conv_output.shape`获取卷积输出张量的形状信息,其中`bs`表示批次大小,`oc`表示输出通道数,`oh`和`ow`分别表示输出张量的高度和宽度。 3. `patch_embedding = conv_output.reshape((bs, oc, oh*ow))`: 这一行代码通过`reshape`函数将卷积输出张量进行形状变换,将其转换为形状为`(bs, oc, oh*ow)`的张量。 4. `patch_embedding = patch_embedding.transpose(-1, -2)`: 这一行代码使用`transpose`函数交换张量的最后两个维度,将形状为`(bs, oh*ow, oc)`的张量转换为`(bs, oc, oh*ow)`的张量。 5. `weight = weight.transpose(0, 1)`: 这一行代码将权重张量进行转置操作,交换第0维和第1维的位置。 6. `kernel = weight.reshape((-1, ic, patch_size, patch_size))`: 这一行代码通过`reshape`函数将权重张量进行形状变换,将其转换为形状为`(outchannel*inchannel, ic, patch_size, patch_size)`的张量。 7. `patch_embedding_conv = image2emb_conv(image, kernel, patch_size)`: 这一行代码调用了`image2emb_conv`函数,并传入了图像、权重张量和补丁大小作为参数。 8. `print(patch_embedding_conv.shape)`: 这一行代码打印了`patch_embedding_conv`的形状信息。 以上是对Vision Transformer代码的逐行解析
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值