Transformer中的类别嵌入

类别嵌入

self.class_embedding = nn.Parameter(scale * torch.randn(width))

这一行代码的作用是在 VisionTransformer 类中创建并初始化一个类别嵌入向量(class embedding vector),用于表示输入序列的类别信息。

详细解释

类别嵌入

在 Transformer 模型中,类别嵌入(class embedding)是一种特殊的嵌入向量,通常用于表示整个输入序列的全局信息。它在视觉变压器(Vision Transformer)模型中起到了类似于 [CLS](分类)标记在 BERT 中的作用。

代码解释
self.class_embedding = nn.Parameter(scale * torch.randn(width))
1. torch.randn(width)
  • 作用:生成一个形状为 (width,) 的张量,其元素从标准正态分布(均值为0,标准差为1)中随机采样。
  • 示例:假设 width = 768,那么 torch.randn(width) 将生成一个包含 768 个元素的一维张量,每个元素为从标准正态分布中采样的随机数。
2. scale *
  • 作用:对生成的随机张量进行缩放。
  • scale 的定义:在代码中,scale 被定义为 width ** -0.5,即 1 / sqrt(width)。这个缩放因子通常用于标准化初始化值,使其具有适当的尺度。
  • 示例:假设 width = 768,那么 scale = 768 ** -0.5 = 1 / sqrt(768)
3. nn.Parameter(...)
  • 作用:将张量封装为一个 nn.Parameter 对象,这意味着这个张量将被视为模型的参数,会在训练过程中进行优化。
  • nn.Parameter 的使用nn.Parametertorch.Tensor 的子类,当其被赋值给 nn.Module(如 nn.Linearnn.Conv2d 等)的属性时,会被自动添加到模块的参数列表中,并在调用 model.parameters() 时返回。

综合解释

self.class_embedding = nn.Parameter(scale * torch.randn(width))
  • 生成随机张量:首先,通过 torch.randn(width) 生成一个形状为 (width,) 的随机张量,其元素从标准正态分布中采样。
  • 缩放张量:然后,将这个张量乘以 scale(即 width ** -0.5),以标准化其值。
  • 创建参数:最后,将这个缩放后的张量封装为一个 nn.Parameter 对象,并赋值给 self.class_embedding

作用

  • 类别嵌入self.class_embedding 是一个可训练的参数,用于表示输入序列的类别信息。
  • 模型训练:在模型训练过程中,self.class_embedding 会被优化,以适应训练数据,帮助模型更好地捕捉全局信息。

示例

假设我们有一个 width = 768VisionTransformer 模型实例:

import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, width):
        super().__init__()
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))

# 创建一个 width 为 768 的 VisionTransformer 实例
vit = VisionTransformer(768)

# 打印 class_embedding 的形状和部分值
print(vit.class_embedding.shape)  # 输出: torch.Size([768])
print(vit.class_embedding[:5])  # 输出: 类别嵌入的前 5 个元素的值

这段代码会输出 class_embedding 的形状(应为 [768])和前 5 个元素的值。

总结

这一行代码 self.class_embedding = nn.Parameter(scale * torch.randn(width))VisionTransformer 类中创建并初始化了一个类别嵌入向量。该向量是一个可训练的参数,用于表示输入序列的全局信息,并在模型训练过程中进行优化。

为什么它用于表示输入序列的全局信息?

类别嵌入(class embedding)用于表示输入序列的全局信息是基于其设计和应用目的,尤其是在 Transformer 模型中的使用方法。让我们详细解释为什么类别嵌入在 Vision Transformer 中起到表示全局信息的作用。

类别嵌入的作用和背景

在 Transformer 模型中,类别嵌入(class embedding 或者 class token)是一种特殊的嵌入向量,通常添加到输入序列的开头。它的设计灵感来自于 BERT 模型中的 [CLS] token,用于捕获整个输入序列的全局信息,并且常用于分类任务。

Vision Transformer (ViT) 中的类别嵌入

在 Vision Transformer 模型中,输入图像首先被划分成若干个固定大小的图像补丁(patches),然后这些补丁被展平并嵌入到一个高维特征空间中。这些补丁嵌入被视为一个序列,类似于自然语言处理任务中的单词序列。

具体步骤
  1. 图像划分为补丁

    • 原始图像被划分为若干个不重叠的图像补丁。
    • 每个补丁被展平成一个向量,并通过线性变换(卷积操作)嵌入到高维空间。
  2. 添加类别嵌入

    • 类别嵌入(class embedding)被添加到补丁序列的开头。这个类别嵌入是一个可训练的参数,用于捕获整个序列的全局信息。
    • 序列的第一个位置(位置 0)被保留给类别嵌入,后续位置由图像补丁嵌入填充。
  3. 位置嵌入

    • 位置嵌入(positional embedding)被加到每个补丁嵌入和类别嵌入上,用于保留序列中的位置信息。
  4. 通过 Transformer 模块

    • 这个序列被传递给 Transformer 模块进行处理。Transformer 模块会对整个序列(包括类别嵌入和补丁嵌入)进行多层注意力和前馈神经网络的计算。
  5. 提取全局信息

    • 在处理完成后,类别嵌入会包含整个输入序列的全局信息,因为它通过多头自注意力机制与所有补丁嵌入进行了交互。
    • 通常在分类任务中,类别嵌入被传递给一个分类头(classification head)以进行最终的分类预测。

代码解释

在你的 VisionTransformer 代码中,这一行定义了类别嵌入:

self.class_embedding = nn.Parameter(scale * torch.randn(width))

在前向传播中,类别嵌入被添加到补丁序列的开头:

x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
具体步骤解释
  1. 初始化类别嵌入

    • self.class_embedding 是一个形状为 (width,) 的可训练参数,用于表示类别嵌入。
  2. 添加到序列开头

    • 创建一个与输入张量 x 同 dtype 和设备的全零张量,其形状为 [batch_size, 1, width]
    • 将类别嵌入加到这个全零张量上,并与输入张量 x 在第一个维度上进行拼接。
    • 这样,类别嵌入被添加到每个输入样本的序列开头,形成一个新的张量,其形状为 [batch_size, grid ** 2 + 1, width]
  3. 通过 Transformer 模块

    • 类别嵌入与图像补丁嵌入一起传递给 Transformer 模块,进行多层自注意力计算。
    • Transformer 模块会将类别嵌入与所有补丁嵌入交互,从而使类别嵌入逐步聚合整个序列的信息。

总结

类别嵌入(class embedding)在 Vision Transformer 中用于表示输入序列的全局信息,因为它在前向传播中通过自注意力机制与所有图像补丁嵌入进行交互,从而捕获整个序列的全局特征。在分类任务中,这个类别嵌入最终用于做出分类决策,因此它被设计为包含全局信息。

添加类别嵌入

x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]

这行代码的作用是将类别嵌入(class embedding)添加到输入张量的序列开头。具体步骤如下:

  1. 创建一个类别嵌入向量并调整其数据类型和设备

    • self.class_embedding.to(x.dtype) 将类别嵌入向量的类型转换为输入张量 x 的数据类型。
    • torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) 创建一个形状为 [batch_size, 1, width] 的全零张量,其数据类型和设备与输入张量 x 一致。
    • self.class_embedding.to(x.dtype) + torch.zeros(...) 创建一个新的张量,其第一维度是类别嵌入,形状为 [batch_size, 1, width]
  2. 将类别嵌入添加到输入张量 x 的序列开头

    • x 是输入张量,其形状为 [batch_size, num_patches, width]
    • torch.cat([...], dim=1) 在第一个维度(即序列维度)上拼接类别嵌入和输入张量 x
  3. 更新后的张量形状

    • 拼接后的张量形状为 [batch_size, num_patches + 1, width],其中 num_patches + 1 表示原始的 num_patches 加上一个类别嵌入。

代码详解

让我们逐步分解这行代码:

1. 调整类别嵌入的数据类型
self.class_embedding.to(x.dtype)
  • self.class_embedding 是一个形状为 [width] 的向量。
  • .to(x.dtype) 将其数据类型转换为输入张量 x 的数据类型。
2. 创建一个全零张量
torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
  • x.shape[0] 是批次大小(batch size)。
  • 1 是类别嵌入的位置。
  • x.shape[-1] 是特征宽度(width)。
  • dtype=x.dtype 设置全零张量的数据类型与输入张量 x 一致。
  • device=x.device 设置全零张量的设备与输入张量 x 一致。
3. 添加类别嵌入
self.class_embedding.to(x.dtype) + torch.zeros(...)
  • 将类别嵌入向量与全零张量相加,形成一个形状为 [batch_size, 1, width] 的张量。
4. 拼接张量
torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
  • 使用 torch.cat 在第一个维度(序列维度)上拼接类别嵌入和输入张量 x
  • 最终形状为 [batch_size, num_patches + 1, width]

示例代码

以下是一个完整的示例,展示了上述步骤:

import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, width):
        super().__init__()
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
    
    def forward(self, x: torch.Tensor):
        # 模拟输入张量 x,假设其形状为 [batch_size, num_patches, width]
        batch_size = x.shape[0]
        num_patches = x.shape[1]
        width = x.shape[2]
        
        # 添加类别嵌入
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(batch_size, 1, width, dtype=x.dtype, device=x.device), x], dim=1)
        return x

# 创建 VisionTransformer 实例
width = 768
vit = VisionTransformer(width)

# 创建一个模拟输入张量 x
batch_size = 2
num_patches = 196
x = torch.randn(batch_size, num_patches, width)

# 调用 forward 方法
output = vit(x)

# 输出张量形状
print(output.shape)  # 输出: torch.Size([2, 197, 768])

总结

这行代码通过在序列开头添加一个类别嵌入向量,将输入张量的形状从 [batch_size, num_patches, width] 扩展到 [batch_size, num_patches + 1, width]。这个类别嵌入向量在 Transformer 模型中用于捕获整个序列的全局信息,通常用于分类任务的最终决策。

什么是一个序列?

在计算机科学和数据处理的上下文中,一个序列(sequence)通常是指一个有序的元素集合,这些元素按照一定的顺序排列,可以是数字、字符、图像补丁等。在不同的应用场景中,序列的具体形式和内容可能有所不同,但它们都有一个共同点,即元素之间存在顺序关系。

序列的概念

一般概念
  • 序列 是一个元素的有序集合,元素可以是任何类型的数据(如数字、字符、图像块等)。
  • 顺序关系 是指序列中的元素按照特定的顺序排列,顺序信息通常是重要的,影响对序列的处理和理解。
具体示例
  1. 数值序列:如 1, 2, 3, 4, 5。这些数字按照从小到大的顺序排列。
  2. 字符序列:如 "hello"。字符 'h', 'e', 'l', 'l', 'o' 按照它们在字符串中的顺序排列。
  3. 图像补丁序列:如将一张图像划分为多个小的图像块,这些块按照它们在图像中的位置顺序排列。

在深度学习中的序列

在深度学习和神经网络中,序列数据是非常常见的输入类型。以下是几个典型的例子:

自然语言处理(NLP)

在 NLP 中,输入通常是单词序列或字符序列。比如:

  • 单词序列:一个句子 “The cat sat on the mat” 可以被表示为一个单词序列 ["The", "cat", "sat", "on", "the", "mat"]
  • 字符序列:同一个句子可以被表示为一个字符序列 ["T", "h", "e", " ", "c", "a", "t", " ", "s", "a", "t", " ", "o", "n", " ", "t", "h", "e", " ", "m", "a", "t"]
时间序列数据

在时间序列分析中,数据通常是按时间顺序排列的观测值。例如:

  • 股票价格序列:记录某只股票在不同时间点的价格。
  • 传感器数据序列:记录传感器在不同时间点的读数。
图像处理

在图像处理任务中,图像可以被划分为多个小块,每个小块作为序列中的一个元素进行处理。这种方法在 Vision Transformer (ViT) 中得到了应用。

Vision Transformer (ViT) 中的序列

在 Vision Transformer 中,图像被划分为固定大小的补丁(patches),这些补丁被视为一个序列来处理。具体步骤如下:

  1. 图像划分为补丁

    • 将输入图像划分为固定大小的非重叠补丁。例如,一个 224x224 的图像可以划分为 16x16 的补丁,总共有 (224/16)^2 = 196 个补丁。
  2. 将补丁展平并嵌入

    • 每个补丁被展平成一个向量,并通过线性变换嵌入到一个高维特征空间(例如,768 维)。
  3. 形成补丁序列

    • 将所有补丁嵌入向量连接起来,形成一个补丁嵌入序列,其形状为 [batch_size, num_patches, embedding_dim]
  4. 添加类别嵌入和位置嵌入

    • 类别嵌入(class embedding)被添加到序列的开头,用于捕获整个图像的全局信息。
    • 位置嵌入(positional embedding)被加到每个补丁嵌入上,以保留序列中的位置信息。

示例代码

以下是一个简单的 Vision Transformer 示例代码:

import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, patch_size: int, stride_size: int, width: int, num_patches: int, output_dim: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=stride_size, bias=False)
        
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers=12, heads=12)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
    
    def forward(self, x: torch.Tensor):
        # 提取补丁嵌入
        x = self.conv1(x)  # [batch_size, width, num_patches_h, num_patches_w]
        x = x.flatten(2)   # [batch_size, width, num_patches]
        x = x.transpose(1, 2)  # [batch_size, num_patches, width]
        
        # 添加类别嵌入
        cls_embed = self.class_embedding.expand(x.shape[0], -1).unsqueeze(1)  # [batch_size, 1, width]
        x = torch.cat((cls_embed, x), dim=1)  # [batch_size, num_patches + 1, width]
        
        # 添加位置嵌入
        x = x + self.positional_embedding
        
        # 前向传播
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # 转换为 [sequence_length, batch_size, width]
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # 转换为 [batch_size, sequence_length, width]
        x = self.ln_post(x[:, 0, :])  # 取出类别嵌入
        
        # 投影到输出维度
        x = x @ self.proj
        
        return x

# 模拟输入图像
input_image = torch.randn(1, 3, 224, 224)  # [batch_size, channels, height, width]

# 创建 VisionTransformer 实例
vit = VisionTransformer(patch_size=16, stride_size=16, width=768, num_patches=(224 // 16) ** 2, output_dim=1000)

# 前向传播
output = vit(input_image)
print(output.shape)  # [1, 1000]

总结

序列是一个有序的元素集合,在深度学习中,序列数据(如文本、时间序列、图像补丁等)可以通过 Transformer 模型进行有效处理。Vision Transformer 通过将图像划分为补丁并视为序列,使得图像处理任务可以直接应用于序列建模技术,从而利用 Transformer 在捕捉长距离依赖关系方面的优势。

  • 13
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yiruzhao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值