1、引言
在本文中,我们将探讨近两年来最具影响力的模型架构之一——Transformer模型。自从2017年Vaswani等人发表的论文《注意力是你所需要的全部》以来,Transformer架构在多个领域持续刷新着性能记录,尤其是在自然语言处理(NLP)领域。这种拥有庞大参数量的Transformer模型能够生成长篇且具有说服力的文章,为人工智能的新应用领域开辟了道路。鉴于Transformer架构的热潮在未来几年内似乎不会消退,理解其工作原理并亲自实现它变得至关重要。
尽管Transformer在自然语言处理领域取得了巨大成功,但我们在本教程中不会涉及NLP。原因有三:首先,人工智能的相关文章中有许多优秀的NLP课题研究,这些研究深入探讨Transformer架构在NLP中的应用。其次,GPT等大模型深入研究了自然语言级别的语言生成,读者可以轻松地将Transformer架构应用于其中。最后,也是最关键的,Transformer架构的应用远不止于此。虽然NLP是Transformer架构最初被提出并产生最大影响的领域,但它也推动了其他领域的研究进展,甚至包括计算机视觉。因此,我们将重点讨论Transformer和自注意力机制为何如此强大。在后续的文章中,我们将讨论Transformer在计算机视觉中的应用。
接下来,我们将导入一些标准库,同时,我们还会使用PyTorch Lightning作为辅助框架。如果你对PyTorch Lightning还不太熟悉,请仔细阅读前一篇博文。
# 导入标准库
import os # 用于操作系统功能
import numpy as np # 用于数值计算
import random # 用于生成随机数
import math # 用于数学运算
import json # 用于处理JSON数据
from functools import partial # 用于函数的部分应用
# 导入绘图相关的库
import matplotlib.pyplot as plt # 用于绘图
plt.set_cmap('cividis') # 设置颜色映射
%matplotlib inline # 使matplotlib图形在Jupyter Notebook中显示
from IPython.display import set_matplotlib_formats # 用于设置matplotlib的输出格式
set_matplotlib_formats('svg', 'pdf') # 设置输出格式为SVG和PDF,便于导出
from matplotlib.colors import to_rgb # 用于颜色转换
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0 # 设置线条宽度
import seaborn as sns # 用于数据可视化
sns.reset_orig() # 重置seaborn的默认设置
# 导入进度条库
from tqdm.notebook import tqdm # 用于显示加载进度条
# 导入PyTorch库
import torch # 用于深度学习
import torch.nn as nn # 用于神经网络
import torch.nn.functional as F # 用于神经网络函数
import torch.utils.data as data # 用于数据加载
import torch.optim as optim # 用于优化算法
# 导入Torchvision库
import torchvision # 用于计算机视觉
from torchvision.datasets import CIFAR100 # 用于加载CIFAR100数据集
from torchvision import transforms # 用于数据转换
# 导入PyTorch Lightning库
try:
import pytorch_lightning as pl # 用于简化训练过程
except ModuleNotFoundError: # 如果未安装PyTorch Lightning,则安装它
!pip install --quiet pytorch-lightning>=1.4
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint # 用于监控学习率和保存模型
# 设置数据集和预训练模型的路径
DATASET_PATH = "../data" # 数据集路径
CHECKPOINT_PATH = "../saved_models/tutorial6" # 预训练模型保存路径
# 设置随机种子
pl.seed_everything(42) # 确保结果可复现
# 确保在GPU上的所有操作都是确定性的(如果使用GPU)
torch.backends.cudnn.deterministic = True # 确保CuDNN的确定性
torch.backends.cudnn.benchmark = False # 关闭CuDNN的基准测试
# 根据是否可用选择GPU或CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device) # 打印设备信息
在准备运行代码之前,请务必确保已经根据实际需要调整了CHECKPOINT_PATH的路径,因为接下来将下载两个关键的预训练模型。如果此路径尚未设置或需要更新,请立即进行相应调整,以确保代码能够正确执行并加载这些重要的模型文件。
import urllib.request # 导入urllib.request库,用于处理URL请求
from urllib.error import HTTPError # 导入HTTPError,用于处理HTTP请求错误
# 定义存放预训练模型的GitHub URL
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/"
# 定义需要下载的预训练文件列表
pretrained_files = ["ReverseTask.ckpt", "SetAnomalyTask.ckpt"]
# 如果检查点路径不存在,则创建它
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
# 对于每个文件,检查它是否已经存在。如果不存在,尝试下载它。
for file_name in pretrained_files:
file_path = os.path.join(CHECKPOINT_PATH, file_name) # 获取文件的完整路径
# 如果文件名中包含"/",则创建相应的文件夹结构
if "/" in file_name:
os.makedirs(file_path.rsplit("/",1)[0], exist_ok=True)
# 如果文件不存在,则尝试下载
if not os.path.isfile(file_path):
file_url = base_url + file_name # 构建文件的URL
print(f"正在下载 {
file_url}...")
try:
urllib.request.urlretrieve(file_url, file_path) # 尝试下载文件
except HTTPError as e:
# 如果下载过程中出现HTTP错误,打印错误信息
print("下载过程中出现问题。请尝试从GDrive文件夹下载文件,或联系作者并提供完整的输出信息,包括以下错误:\n", e)
2、Transformer架构
在本文中,我们将亲自动手实现Transformer模型架构。虽然这种架构已经广为人知,Pytorch已经提供了一个名为nn.Transformer的模块,并且有关如何利用它进行下一个词预测的教程也已存在。但为了深入理解其核心细节,我们仍将从头开始实现它。
当然,网络上关于注意力机制和Transformer的教程数不胜数。如果你对这一主题感兴趣,并希望在阅读完本文后获得不同的视角,以下是一些推荐资源:
- 《Transformer:一种用于语言理解的新型神经网络架构》(Jakob Uszkoreit, 2017) - 这是谷歌关于Transformer论文的原始博客文章,主要聚焦于机器翻译的应用。
- 《图解Transformer》(Jay Alammar, 2018) - 一篇非常受欢迎且易于理解的博客文章,通过许多精美的可视化图表直观地解释了Transformer架构。重点在于自然语言处理(NLP)。
- 《注意力机制解析》(Lilian Weng, 2018)- 一篇总结不同领域(包括视觉)中注意力机制的博客文章。
- 《自注意力机制图解》(Raimi Karim, 2019) - 对自注意力机制步骤的清晰可视化。如果你觉得下面的解释太过抽象,强烈推荐阅读。
- 《Transformer家族》(Lilian Weng, 2020) - 一篇非常详尽的博客文章,回顾了除了原始版本之外的更多Transformer变体。
2.1.注意力机制的定义
注意力机制是近年来在神经网络中引起广泛关注的一类新型网络层,特别是在处理序列任务时。尽管在学术文献中“注意力”有多种定义,但我们在这里采用以下定义:注意力机制指的是基于输入查询和元素键动态计算权重的元素加权平均。具体来说,这意味着什么?我们的目标是计算多个元素的特征平均值。不过,我们希望根据元素的实际值来赋予不同的权重,而不是平等地对待每个元素。换句话说,我们希望动态地决定哪些输入更值得“关注”。具体来说,注意力机制通常包括以下四个部分:
- 查询(Query):查询是一个特征向量,它描述了我们在序列中寻找的内容,即我们可能想要关注的对象。
- 键(Keys):对于每个输入元素,都有一个键,这也是一个特征向量。这个特征向量大致描述了元素“提供”的内容,或者它何时可能变得重要。键的设计应使我们能够根据查询识别出我们想要关注的元素。
- 值(Values):对于每个输入元素,我们还有一个值向量。这个特征向量是我们希望进行平均计算的。
- 得分函数(Score function):为了评估我们想要关注哪些元素,我们需要定义一个得分函数。得分函数以查询和键为输入,并输出查询-键对的得分或注意力权重。它通常通过简单的相似性度量实现,如点积或小型多层感知器(MLP)。
平均值的权重通过所有得分函数输出的softmax函数计算得出。因此,我们为那些与查询最相似的键对应的值向量赋予更高的权重。如果我们用伪代码来描述这个过程,可以这样写:
α i = exp ( f a t t n ( key i , query ) ) ∑ j exp ( f a t t n ( key j , query ) ) , out = ∑ i α i ⋅ value i \alpha_i = \frac{\exp\left(f_{attn}\left(\text{key}_i, \text{query}\right)\right)}{\sum_j \exp\left(f_{attn}\left(\text{key}_j, \text{query}\right)\right)}, \hspace{5mm} \text{out} = \sum_i \alpha_i \cdot \text{value}_i αi=∑jexp(fattn(keyj,query))exp(fattn(keyi,query)),out=i∑αi⋅valuei
在视觉上,我们可以将注意力机制对一系列单词的关注情况展示如下:

每个单词都对应着一个键向量和一个值向量。通过一个得分函数(在这个场景中是点积)将查询向量与所有键向量进行比较,以此确定各自的权重。为了简化,我们这里没有展示softmax过程。最终,所有单词的值向量会根据这些注意力权重进行加权平均。
大多数注意力机制的不同之处在于它们所采用的查询方式、键和值向量的定义,以及所使用的得分函数。在Transformer模型中应用的注意力机制被称为自注意力。在自注意力机制中,序列中的每个元素都充当键、值和查询的角色。对于序列中的每个元素,我们都会通过一个注意力层来评估其查询与其他所有元素键的相似度,并为每个元素生成一个经过加权平均的新值向量。接下来,我们将通过深入探讨Transformer模型中使用的特定注意力机制——缩放点积注意力——来进一步理解这一概念。
2.2.缩放点积注意力
自注意力机制的核心是缩放点积注意力。我们的目标是构建一种注意力机制,它允许序列中的任意元素都能高效地关注到其他任意元素。点积注意力机制的输入包括一组查询 Q Q Q、键 K K K 和值 V V V,其中 Q , K ∈ R T × d k Q, K \in \mathbb{R}^{T \times d_k} Q,K∈RT×dk 和 V ∈ R T × d v V \in \mathbb{R}^{T \times d_v} V∈RT×dv。这里 T T T 代表序列的长度, d k d_k dk 和 d v d_v dv 分别代表查询/键和值的隐藏维度。为了简化说明,我们这里暂时不考虑批量维度。元素 i i i 到 j j j 之间的注意力值是基于查询 Q i Q_i Qi 和键 K j K_j Kj 的相似度,使用点积作为相似度的度量方式。数学上,点积注意力的计算公式如下:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
这里的矩阵乘法 Q K T QK^T QKT 计算每一对查询和键的点积,生成一个 T × T T \times T T×T 形状的矩阵。每一行代表了特定元素对序列中所有其他元素的注意力得分。我们对这些得分应用softmax函数,并与值向量相乘,从而得到一个加权平均值(权重由注意力值决定)。下面是一个计算图,它从另一个角度展示了这种注意力机制(图示来源:Vaswani et al., 2017)。

我们尚未讨论的一点是 1 d k \frac{1}{\sqrt{d_k}} dk1 这个缩放因子的重要性。这个缩放因子对于在模型初始化后保持注意力值的适当方差非常关键。我们知道,初始化模型时我们希望各层的方差保持一致,因此查询 Q Q Q 和键 K K K 的方差也可能接近 1。但是,如果两个向量的方差为 σ 2 \sigma^2 σ2,那么它们点积的结果方差将会是 σ 4 ⋅ d k \sigma^4 \cdot d_k σ4⋅dk:
q i ∼ N ( 0 , σ 2 ) , k i ∼ N ( 0 , σ 2 ) → Var ( ∑ i = 1 d k q i ⋅ k i ) = σ 4 ⋅ d k q_i \sim \mathcal{N}(0,\sigma^2), k_i \sim \mathcal{N}(0,\sigma^2) \to \text{Var}\left(\sum_{i=1}^{d_k} q_i\cdot k_i\right) = \sigma^4\cdot d_k qi∼N(0,σ2),ki∼N(0,σ2)→Var(i=1∑dkqi⋅ki)=σ4⋅dk
如果不将方差重新缩放至 σ 2 \sigma^2 σ2 附近,那么在对数几率上应用 softmax 函数时,一个随机元素的值将会饱和至 1,而其他所有元素的值则会接近 0。这将导致通过 softmax 的梯度几乎为零,从而使得我们无法适当地学习模型参数。需要注意的是,方差中的额外 σ 2 \sigma^2 σ2 因子,即 σ 4 \sigma^4 σ4 而不是 σ 2 \sigma^2 σ2,通常不会造成问题,因为我们通常将原始方差 σ 2 \sigma^2 σ2 保持在接近 1 的水平。
在上图表中的掩码模块(mask)表示在注意力矩阵中对特定条目进行可选的掩码操作。例如,当我们将不同长度的多个序列堆叠到一个批次中时,就会使用这种方法。为了在 PyTorch 中仍然能够利用并行计算的优势,我们通常会将句子填充到相同的长度,并在计算注意力值时忽略这些填充标记。这通常是通过将相关的注意力对数几率设置为一个极低的值来实现的。
在详细讨论了缩放点积注意力块的细节之后,我们可以编写一个函数,该函数接受查询、键和值的三元组,并计算输出特征:
def scaled_dot_product(q, k, v, mask=None):
d_k = q.size(-1) # 获取查询向量的维度
attn_logits = torch.matmul(

最低0.47元/天 解锁文章
9万+

被折叠的 条评论
为什么被折叠?



