Transformer 核心代码multi-head self attention
transformer 编码器和解码器架构,主要的结构是multi-head self-attention,残差连接,FFN,层归一化。其中核心的代码是关于multi-head self-attention的实现,下面给出一个逐行代码解释。
首先回顾一下计算公式:
设输入序列为
x
1
:
T
∈
R
D
h
×
T
\mathbf{x}_{1:T} \in \mathbb{R}^{D_h \times T}
x1:T∈RDh×T,其嵌入表示并添加位置编码表示位置信息,有
H
=
[
e
x
1
+
p
1
,
⋯
,
e
x
T
+
p
T
]
\mathbf{H}=[e_{\mathbf{x}_1}+p_1,\cdots,e_{\mathbf{x}_T}+p_T]
H=[ex1+p1,⋯,exT+pT]
多头自注意力(Multi-HeadSelf-Attention),在多个不同的投影空间中捕捉不同的交互信息.假设在 M M M 个投影空间中分别应用自注意力模型,有
MultiHead ( Q , K , V ) = W o [ head 1 ; ⋯ ; head M ] , head m = self-att ( Q m , K m , V m ) , ∀ m ∈ { 1 , ⋯ , M } , Q m = W q m Q , K = W k m K , V = W v m V , \begin{aligned} \operatorname{MultiHead}(\mathbf{Q},\mathbf{K},\mathbf{V})&=\boldsymbol{W}_o[\operatorname{head}_1;\cdots;\operatorname{head}_M],\\ \operatorname{head}_m&=\operatorname{self-att}(\boldsymbol{Q}_m,\boldsymbol{K}_m,\boldsymbol{V}_m),\\ \forall m\in\{1,\cdots,M\},\quad \boldsymbol{Q}_m&=\boldsymbol{W}_q^m\boldsymbol{Q},\boldsymbol{K}=\boldsymbol{W}_k^m\boldsymbol{K},\boldsymbol{V}=\boldsymbol{W}_v^m\boldsymbol{V},\end{aligned} MultiHead(Q,K,V)headm∀m∈{1,⋯,M},Qm=Wo[head1;⋯;headM],=self-att(Qm,Km,Vm),=WqmQ,K=WkmK,V=WvmV,
其中 W o ∈ R D h × M ⋅ D v \boldsymbol{W}_o \in \mathbb{R}^{D_h \times M \cdot D_v} Wo∈RDh×M⋅Dv为输出投影矩阵, W q m ∈ R D k × D h \boldsymbol{W}_q^m \in \mathbb{R}^{D_k \times D_h} Wqm∈RDk×Dh, W k m ∈ R D k × D h \boldsymbol{W}_k^m \in \mathbb{R}^{D_k \times D_h} Wkm∈RDk×Dh, W q m ∈ R D v × D h \boldsymbol{W}_q^m \in \mathbb{R}^{D_v \times D_h} Wqm∈RDv×Dh, 为投影矩阵, m ∈ { 1 , . . . , M } m \in \{ 1,...,M\} m∈{1,...,M}.自注意力模型 self-att \operatorname{self-att} self-att可以看作在一个线性投影空间中建立 H \mathbf{H} H(自注意力中 Q , K , V 均为 H \mathbf{Q},\mathbf{K},\mathbf{V}均为\mathbf{H} Q,K,V均为H)中不同向量之间的交互关系.其计算公式为
self-att ( Q , K , V ) = V softmax ( K ⊺ Q D k ) , Q = W q H , K = W k H , V = W υ H , \begin{aligned}\operatorname{self-att}(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})&=\boldsymbol{V}\operatorname{softmax}\big(\frac{\boldsymbol{K}^\intercal\boldsymbol{Q}}{\sqrt{D_k}}\big),\\\boldsymbol{Q}&=\boldsymbol{W}_q\mathbf{H},\boldsymbol{K}=\boldsymbol{W}_k\mathbf{H},\boldsymbol{V}=\boldsymbol{W}_\upsilon\mathbf{H},\end{aligned} self-att(Q,K,V)Q=Vsoftmax(DkK⊺Q),=WqH,K=WkH,V=WυH,其中 D k D_k Dk是输入矩阵 Q \boldsymbol{Q} Q和 K \boldsymbol{K} K中列向量的维度, W q ∈ R D k × D h \boldsymbol{W}_q \in \mathbb{R}^{D_k \times D_h} Wq∈RDk×Dh, W k ∈ R D k × D h \boldsymbol{W}_k \in \mathbb{R}^{D_k \times D_h} Wk∈RDk×Dh, W v ∈ R D v × D h \boldsymbol{W}_v \in \mathbb{R}^{D_v \times D_h} Wv∈RDv×Dh为三个投影矩阵。
通过对公式的具体分析,在多头注意力机制的计算中,频繁涉及到key、query和value向量的构造,因此我们需要一个为多头注意力机制准备key、query和value向量的封装函数。
PrepareForMultiHeadAttention
这段代码定义了一个PyTorch模块,
PrepareForMultiHeadAttention
,其主要作用是为多头注意力机制准备key、query和value向量。这一过程涉及线性变换和将变换后的向量分割成指定数量的头。下面是对代码的逐行解释:类定义
class PrepareForMultiHeadAttention(nn.Module):
定义了一个名为PrepareForMultiHeadAttention
的类,该类继承自PyTorch的nn.Module
类。这表明PrepareForMultiHeadAttention
是一个可以集成到PyTorch模型中的自定义模块。构造函数
__init__
def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
这个构造函数接受四个参数:d_model
是输入向量的维度;heads
是要分割的头的数量;d_k
是分割后每个头中向量的维度;bias
是一个布尔值,指定是否在线性变换中添加偏置项。self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
创建一个线性层,用于对输入向量进行线性变换。变换的输出维度是头的数量乘以每个头的维度。也就是对输入表示向量的维度进行切分,切成heads份,每份大小是d_k。self.heads = heads
和self.d_k = d_k
分别存储了头的数量和每个头中向量的维度,这些信息在前向传播时用于重塑变换后的向量。前向传播
forward
def forward(self, x: torch.Tensor):
定义了模块的前向传播逻辑,其中x
是输入张量,其形状可以是[seq_len, batch_size, d_model]
或[batch_size, d_model]
。head_shape = x.shape[:-1]
获取输入张量x
除最后一个维度外的形状,这用于后续重塑变换后的向量。x = self.linear(x)
对输入x
应用前面定义的线性变换。x = x.view(*head_shape, self.heads, self.d_k)
重塑线性变换后的向量,使其最后两个维度分别为头的数量和每个头的向量维度。这样,输出张量的形状变为[seq_len, batch_size, heads, d_k]
或[batch_size, heads, d_k]
。总结
这个模块在多头注意力机制中扮演关键角色,通过对key、query和value向量进行适当的线性变换和重塑,使得它们能够被分配到不同的“头”中。这种分配使得模型能够在不同的表示子空间中并行捕获信息,从而提高了模型处理复杂信息的能力。
输入:
d_model
:输入向量的维度;heads
:要分割的头的数量;d_k
:分割后每个头中向量的维度,即 d k = d model / heads d_k = d_{\text{model}}/\text{heads} dk=dmodel/heads;bias
:一个布尔值,指定是否在线性变换中添加偏置项。
transformer中,Q,K,V的维度大小通常是相等的,因此 d v = d k d_v = d_k dv=dk。
输出:输出张量的形状变为[seq_len, batch_size, heads, d_k]
或[batch_size, heads, d_k]
。
import math
from typing import Optional, List
import torch
from torch import nn
from labml import tracker
class PrepareForMultiHeadAttention(nn.Module):
"""
## Prepare for multi-head attention
This module does a linear transformation and splits the vector into given
number of heads for multi-head attention.
This is used to transform **key**, **query**, and **value** vectors.
该模块进行线性变换,并将向量分割为给定数量的头以进行多头注意力。这用于变换 **key**、**query** 和 **value** 向量。
"""
def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
super().__init__()
# Linear layer for linear transform
self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
# Number of heads
self.heads = heads
# Number of dimensions in vectors in each head
self.d_k = d_k
def forward(self, x: torch.Tensor):
# Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
# We apply the linear transformation to the last dimension and split that into
# the heads.
head_shape = x.shape[:-1]
# Linear transform
x = self.linear(x)
# Split last dimension into heads
x = x.view(*head_shape, self.heads, self.d_k)
# Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_model]`
return x
接着根据多头注意力的公式可以逐行实现Multi-Head Attention module,不过在这个过程中需要注意的是,实际计算的输入是[seq_len,batch_size,heads,d_k]
,因此在计算key和query的乘积时,使用的是torch.enisum()
方法,它可以为各种张量运算提供一个简洁的框架,具体解释参考官方文档
另外一个需要注意的点是在解码器中用到的masked multi-head attention,这是通过设置掩码来实现的。
Multi-Head Attention Module
这段代码定义了一个
MultiHeadAttention
类,它是Transformer模型中多头注意力机制的实现。以下是对该类及其方法的详细解释:类定义和初始化
class MultiHeadAttention(nn.Module):
定义了MultiHeadAttention
类,继承自PyTorch的nn.Module
。- 在
__init__
方法中,类接收几个参数:
heads
: 多头注意力机制中头的数量。d_model
: 输入向量的特征维度,也是query
、key
和value
向量的维度。dropout_prob
: dropout操作的概率,用于防止过拟合。bias
: 是否在PrepareForMultiHeadAttention
中使用偏置项。self.d_k = d_model // heads
计算每个头的特征维度。self.query
,self.key
,self.value
使用PrepareForMultiHeadAttention
类分别对query
、key
和value
向量进行线性变换,为多头注意力计算做准备。self.softmax
定义了在计算注意力时沿着键(key)的时间(列)维度应用的softmax函数。self.output
是一个线性层,用于将多头注意力的输出重新映射回原始的特征空间d_model
。self.dropout
定义了dropout操作,用于在注意力权重上进行。self.scale
是缩放因子,用于在计算softmax之前调整注意力分数,以避免因维度较大而导致的梯度消失或爆炸。
class MultiHeadAttention(nn.Module):
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
"""
* `heads` is the number of heads.
* `d_model` is the number of features in the `query`, `key` and `value` vectors.
"""
super().__init__()
# Number of features per head
self.d_k = d_model // heads
# Number of heads
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
# input:[d_m,heads,d_k]
# output:[seq_len, batch_size, heads, d_k]
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
# Softmax for attention along the time dimension of `key`
# 也就是沿着key的列方向进行softmax
self.softmax = nn.Softmax(dim=1)
# Output layer
self.output = nn.Linear(d_model, d_model)
# Dropout
self.dropout = nn.Dropout(dropout_prob)
# Scaling factor before the softmax
self.scale = 1 / math.sqrt(self.d_k)
# We store attentions so that it can be used for logging, or other computations if needed
self.attn = None
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
"""
### Calculate scores between queries and keys
This method can be overridden for other variations like relative attention.
"""
# Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
# 计算K^T 与 Q的乘积
return torch.einsum('ibhd,jbhd->ijbh', query, key)
def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
"""
`mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where first dimension is the query dimension.
If the query dimension is equal to $1$ it will be broadcasted.
"""
assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
assert mask.shape[1] == key_shape[0]
assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
# Same mask applied to all heads.
mask = mask.unsqueeze(-1)
# resulting mask has shape `[seq_len_q, seq_len_k, batch_size, heads]`
return mask
def forward(self, *,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None):
"""
`query`, `key` and `value` are the tensors that store
collection of *query*, *key* and *value* vectors.
They have shape `[seq_len, batch_size, d_model]`.
`mask` has shape `[seq_len, seq_len, batch_size]` and
`mask[i, j, b]` indicates whether for batch `b`,
query at position `i` has access to key-value at position `j`.
"""
# `query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
seq_len, batch_size, _ = query.shape
if mask is not None:
mask = self.prepare_mask(mask, query.shape, key.shape)
# Prepare `query`, `key` and `value` for attention computation.
# These will then have shape `[seq_len, batch_size, heads, d_k]`.
query = self.query(query)
key = self.key(key)
value = self.value(value)
# Compute attention scores $Q K^\top$.
# This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
scores = self.get_scores(query, key)
# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
scores *= self.scale
# Apply mask
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# $softmax$ attention along the key sequence dimension
# $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
attn = self.softmax(scores)
# Save attentions if debugging
tracker.debug('attn', attn)
# Apply dropout
attn = self.dropout(attn)
# Multiply by values
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# Save attentions for any other calculations
self.attn = attn.detach()
# Concatenate multiple heads
x = x.reshape(seq_len, batch_size, -1)
# Output layer
return self.output(x)
方法解释
- get_scores方法通过
torch.einsum
计算所有query
和所有key
的点积,得到注意力分数矩阵。
torch.einsum
是一个非常强大的函数,用于执行爱因斯坦求和约定(Einstein summation convention),它提供了一种表达多维数组之间复杂操作的简洁方式。torch.einsum
接受一个操作字符串和若干个张量作为输入,操作字符串指定了输入张量的维度如何相乘和求和。在多头注意力的上下文中,
torch.einsum('ibhd,jbhd->ijbh', query, key)
这行代码执行了query矩阵和key矩阵的批量点乘操作,为了生成一个表示注意力分数的矩阵。下面是对这个操作字符串和相应操作的详细解释:torch.einsum(‘ibhd,jbhd->ijbh’, query, key)
- 输入张量:
query
和key
是两个张量,它们的维度都是[seq_len, batch_size, heads, d_k]
,其中:
seq_len
是序列长度,batch_size
是批量大小,heads
是注意力头的数量,d_k
是每个注意力头的维度。- 操作字符串:
'ibhd,jbhd->ijbh'
可以分解为三个部分:
ibhd
: 第一个张量(query)的维度标签。jbhd
: 第二个张量(key)的维度标签。ijbh
: 输出张量的维度标签。解析操作字符串
ibhd,jbhd
: 这表示query
和key
张量进行操作。每个张量的维度用不同的字母表示,相同的字母表示这些维度将进行点乘操作。
i
和j
分别代表query
和key
的序列长度维度。b
代表批量大小(两个张量共享这一维度)。h
代表头的数量(两个张量共享这一维度)。d
代表每个头内的特征或维度,query
和key
在这一维度上进行点乘。->ijbh
: 输出张量的维度。这里没有d
,因为d
维度上的元素被求和了(点乘后求和)。输出张量的维度是:
i
和j
分别代表query
的序列长度和key
的序列长度,这允许每个query
与所有key
进行比较,形成一个注意力分数矩阵。b
代表批量大小。h
代表头的数量。操作含义
这个操作计算了每个头内,每个
query
向量与每个key
向量的乘积,并将结果求和(因为维度d
没有出现在输出中)。这相当于计算注意力机制中的原始分数(未缩放的点乘注意力分数)。对于每个批次中的每个头,你会得到一个
[seq_len, seq_len]
的分数矩阵,表示序列中每个位置的query
如何与序列中每个位置的key
相互作用。这个分数矩阵接下来会被缩放、掩码处理(如果有的话),然后应用softmax函数来得到最终的注意力权重。
- prepare_mask方法用于处理掩码张量,使其适用于序列长度和头的维度。这在处理不同长度的序列时非常有用,可以阻止模型看到序列中的某些部分。也就是masked multi-head attention
- forward方法是执行多头注意力计算的主要函数。它首先调整
query
、key
、value
的形状以适应多头计算,然后计算注意力分数,应用缩放和掩码,最后通过softmax获取注意力权重。使用这些权重和value
计算加权和,最后通过输出层将结果映射回原始维度。
注意力计算过程
- 使用
PrepareForMultiHeadAttention
对query
、key
、value
进行线性变换并分头处理。- 计算
query
和key
的乘积,得到注意力分数矩阵。- 应用缩放因子。
- 如果提供了掩码,应用掩码。
- 对分数应用softmax函数,得到注意力权重。
- 应用dropout到注意力权重上。
- 使用注意力权重对
value
进行加权求和。- 将多个头的输出拼接并通过最后的线性层。
这个实现允许模型在不同的表示子空间中并行捕获信息,这是Transformer架构的关键特性之一,提高了处理复杂信息的能力。