1. 前言
因果注意力机制(causal attention)是一种特殊的自注意力机制,其在计算context向量 z i z_i zi之前,会使用掩码(mask)屏蔽 x i + 1 , x i + 2 , ⋯ , x n x_{i+1}, x_{i+2}, \cdots, x_n xi+1,xi+2,⋯,xn,使注意力权重 α i , i + 1 , α i , i + 2 , ⋯ , α i n \alpha_{i,i+1}, \alpha_{i,i+2}, \cdots, \alpha_{in} αi,i+1,αi,i+2,⋯,αin全都等于0,生成的context向量 z i z_i zi只与 x 1 , x 2 , ⋯ , x i x_1, x_2, \cdots, x_i x1,x2,⋯,xi相关,模型只根据历史信息预测下一个token。
本文介绍因果注意力机制计算生成context向量
z
i
z_i
zi时屏蔽
x
i
+
1
,
x
i
+
2
,
⋯
,
x
n
x_{i+1}, x_{i+2}, \cdots, x_n
xi+1,xi+2,⋯,xn的两种mask方法,使用mask屏蔽因果注意力权重矩阵对注意力机制做Dropout,并实现继承自torch.nn.Module
的神经网络模块CausalAttention
。
2. 因果注意力掩码(Causal Attention Mask)
因果注意力机制计算context向量 z i z_i zi之前,会使用mask屏蔽 x i + 1 , x i + 2 , ⋯ , x n x_{i+1}, x_{i+2}, \cdots, x_n xi+1,xi+2,⋯,xn,使注意力权重 α i , i + 1 , α i , i + 2 , ⋯ , α i n \alpha_{i,i+1}, \alpha_{i,i+2}, \cdots, \alpha_{in} αi,i+1,αi,i+2,⋯,αin全都等于0,并重新归一化注意力权重 α i 1 , α i 2 , ⋯ , α i i \alpha_{i1}, \alpha_{i2}, \cdots, \alpha_{ii} αi1,αi2,⋯,αii,使 ∑ j = 1 i α i j = 1 \sum_{j=1}^i\alpha_{ij}=1 ∑j=1iαij=1。如下图所示,context向量 z 1 z_1 z1对应的注意力权重 α 11 = 1 \alpha_{11}=1 α11=1, α 12 ∼ α 1 n \alpha_{12}\sim\alpha_{1n} α12∼α1n均等于0,context向量 z 2 z_2 z2对应的注意力权重 α 21 + α 22 = 1 \alpha_{21}+\alpha_{22}=1 α21+α22=1, α 23 ∼ α 2 n \alpha_{23}\sim\alpha_{2n} α23∼α2n均等于0。依此类推,context向量 z i z_i zi对应的注意力权重 α i 1 + α i 2 + ⋯ + α i i = 1 \alpha_{i1}+\alpha_{i2}+\cdots+\alpha_{ii}=1 αi1+αi2+⋯+αii=1, α i , i + 1 ∼ α i n \alpha_{i,i+1}\sim\alpha_{in} αi,i+1∼αin均等于0。
2.1 注意力权重掩码(Mask Attention Weights)
最简单直接的因果注意力掩码方法是将注意力权重矩阵与mask矩阵(对角线及其下面元素值均为1的下三角矩阵)逐元素相乘,得到经过掩码的注意力权重矩阵(masked attention weights matrix),然后用经过掩码的注意力权重矩阵中每个元素除以其所在行全部元素的和,将未被mask的注意力权重重新归一化,使经过掩码的注意力权重每行所有注意力权重之和等于1。
可以使用如下代码实例化前文所述ScaledDotProductAttention
类对象,并计算注意力权重矩阵:
import torch
torch.manual_seed(123)
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)
d_in = inputs.shape[1]
d_out = 2
class ScaledDotProductAttention(torch.nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
queries = self.W_query(x)
keys = self.W_key(x)
values = self.W_value(x)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
context_vec = attn_weights @ values
return context_vec
sdpa = ScaledDotProductAttention(d_in, d_out)
queries = sdpa.W_query(inputs)
keys = sdpa.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)
执行上面代码,打印结果如下:
tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
[0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
[0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
[0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
[0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
[0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
grad_fn=<SoftmaxBackward0>)
使用PyTorch内置的tril
函数可以创建上述mask矩阵:
context_len = attn_scores.shape[0]
mask_matrix = torch.tril(torch.ones(context_len, context_len))
print(mask_matrix)
执行上面代码,打印结果如下:
tensor([[1., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1.]])
将注意力权重矩阵与mask矩阵逐元素相乘,得到经过掩码的注意力权重矩阵:
masked_attn_weights = attn_weights * mask_matrix
print(masked_attn_weights)
执行上面代码,打印结果如下:
tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
[0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
[0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
[0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
[0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
grad_fn=<MulBackward0>)
可以使用如下代码计算经过掩码的注意力权重矩阵各行所有元素之和,并用每个元素除以其所在行全部元素的和,将未被mask的注意力权重重新归一化,得到因果注意力权重矩阵:
row_sums = masked_attn_weights.sum(dim=1, keepdim=True)
masked_attn_weights_norm = masked_attn_weights / row_sums
print(masked_attn_weights_norm)
执行上面代码,打印结果如下:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
[0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
[0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
[0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
grad_fn=<DivBackward0>)
2.2 注意力分数掩码(Mask Attention Scores)
可以利用softmax
函数的数学特性用更少的计算步骤获得上述因果注意力权重矩阵。使用softmax
函数可以将注意力分数归一化得到注意力权重,假设某个注意力分数的值为负无穷
−
∞
-\infty
−∞,则相应注意力权重会等于0(因为
e
−
∞
=
0
e^{-\infty}=0
e−∞=0)。
将注意力分数矩阵与mask矩阵(对角线及其下面元素值均为0,对角线上面的元素值均为
−
∞
-\infty
−∞的矩阵)逐元素相加,得到经过掩码的注意力分数矩阵(masked attention scores matrix),然后将经过掩码的注意力分数矩阵输入softmax
函数,计算出上述因果注意力权重矩阵。
可以使用如下代码创建上述mask矩阵,并将注意力分数矩阵与mask矩阵逐元素相加,得到经过掩码的注意力分数矩阵:
mask = torch.triu(torch.ones(context_len, context_len), diagonal=1)
masked_attn_scores = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked_attn_scores)
执行上面代码,打印结果如下:
tensor([[0.3111, -inf, -inf, -inf, -inf, -inf],
[0.1655, 0.2602, -inf, -inf, -inf, -inf],
[0.1667, 0.2602, 0.2577, -inf, -inf, -inf],
[0.0510, 0.1080, 0.1064, 0.0643, -inf, -inf],
[0.1415, 0.1875, 0.1863, 0.0987, 0.1121, -inf],
[0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
grad_fn=<MaskedFillBackward0>)
使用前文所述计算注意力权重方法,先将每个注意力分数除以key向量维度的平方根,再输入softmax
函数,得到因果注意力权重矩阵:
attn_weights = torch.softmax(masked_attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)
执行上面代码,打印结果如下:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
[0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
[0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
[0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
grad_fn=<SoftmaxBackward0>)
3. 注意力机制中的Dropout
Dropout是深度学习实践中已经被证明非常有效的一种防止模型过拟合方法,其在模型每次训练过程中均随机“丢弃”一定比例的隐藏层计算节点,被“丢弃”的隐藏层计算节点不参与训练过程,相应模型参数也不做更新。
如下图所示,对注意力机制做Dropout,可以将注意力权重矩阵与mask矩阵(部分元素值为0,其余元素值均为1的矩阵)逐元素相乘,得到经过Dropout的注意力权重矩阵,然后将经过Dropout的注意力权重矩阵中每个元素均乘以 1 1 − p \frac{1}{1-p} 1−p1,其中 p p p是随机“丢弃”的隐藏层计算节点的比例。
Dropout仅作用于模型训练阶段,将经过Dropout的注意力权重矩阵中每个元素均乘以 1 1 − p \frac{1}{1-p} 1−p1,可以使计训练和推理时算出的context向量数值大小保持一致。
可以使用PyTorch内置的torch.nn.Dropout
创建上述mask矩阵,并对因果注意力机制做Dropout:
dropout = torch.nn.Dropout(0.5)
dropped_attn_weights = dropout(attn_weights)
print(dropped_attn_weights)
执行上面代码,打印结果如下:
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.9665, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.6804, 0.0000, 0.0000, 0.0000],
[0.4889, 0.0000, 0.5085, 0.0000, 0.0000, 0.0000],
[0.3988, 0.4120, 0.0000, 0.3869, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3363]],
grad_fn=<MulBackward0>)
上面代码使用
torch.nn.Dropout
创建dropout对象时,指定了随机“丢弃”的隐藏层计算节点的比例 p p p为0.5。在深度学习实践中,并不会采用如此大的随机“丢弃”比例。 p p p越大,理论上模型越不容易出现过拟合,但是模型会越难收敛。在大语言模型中,一般会将 p p p设置为0.1或0.2。
4. 构建神经网络模块CausalAttention
神经网络模块CausalAttention
与前文所述ScaledDotProductAttention
非常相似。在__init__
方法中,CausalAttention
初始化了3个torch.nn.Linear
层,使用torch.nn.Dropout
初始化了一个用于对注意力机制做Dropout的对象,同时使用父类中的register_buffer
方法注册了一个名为mask
的对角线及其下面元素值均为0且上面元素值均为1的矩阵,用于实现注意力分数掩码。forward
方法计算出attn_scores之后,使用上述注意力分数掩码方法计算出因果注意力权重矩阵并进行Dropout。最后将经过Dropout的因果注意力权重矩阵与values相乘,输出context_vec。具体代码如下所示:
class CausalAttention(torch.nn.Module):
def __init__(self, d_in, d_out, context_len, dropout, qkv_bias=False):
super().__init__()
self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = torch.nn.Dropout(dropout)
self.register_buffer("mask", torch.triu(torch.ones(context_len, context_len), diagonal=1))
def forward(self, x):
batch_size, num_tokens, d_in = x.shape
queries = self.W_query(x)
keys = self.W_key(x)
values = self.W_value(x)
attn_scores = queries @ keys.transpose(1, 2)
attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
context_vec = attn_weights @ values
return context_vec
使用
register_buffer
方法可以在buffers中创建mask矩阵,buffers中的参数会随着模型一起自动移动到相应的设备(CPU或GPU)中,不需要手动将相应参数移动到模型所在设备,可以避免使用不同设备训练或推理时的设备不匹配错误。在深度学习实践中,训练大语言模型往往会同时输入一个batch的训练样本。前文实现的
ScaledDotProductAttention
并不支同时输入一个batch的样本数据。当forward
方法的输入x
包含一个batch的样本数据,x
将会是一个3D的张量,其中第1个维度为batch_size(输入batch中包含的训练样本总数),第2个维度表示num_tokens(每个训练样本包含的tokens总数),第三个维度表示embedding_dim(每个token对应的Embedding向量维度)。因此CausalAttention
类的forward
方法中使用transpose
函数对keys张量的第1和2维度进行转置。
可以使用如下代码实例化CausalAttention
类对象,并输入一个batch的样本数据,计算context向量:
batch_inputs = torch.stack((inputs, inputs), dim=0)
context_len = batch_inputs.shape[1]
ca = CausalAttention(d_in, d_out, context_len, 0.1, qkv_bias=False)
context_vecs = ca(batch_inputs)
print(context_vecs)
执行上面代码,打印结果如下:
tensor([[[ 0.0000, 0.0000],
[-0.0787, 0.2196],
[-0.3091, 0.3692],
[-0.2929, 0.3258],
[-0.2442, 0.2429],
[-0.2028, 0.2132]],
[[-0.1635, 0.4562],
[-0.2770, 0.3942],
[-0.2573, 0.2245],
[-0.2524, 0.2129],
[-0.2442, 0.2429],
[-0.2307, 0.2418]]], grad_fn=<UnsafeViewBackward0>)
因为使用了Dropout,相同inputs对应的注意力权重矩阵中随机“丢弃”的计算节点并不相同,因此打印输出的context_vecs中第一个inputs对应的context向量
context_vecs[0]
与第二个inputs对应的context向量context_vecs[1]
并不相同。
5. 结束语
因果注意力机制也被称作掩码注意力机制(masked attention),其使用mask矩阵屏蔽注意力权重矩阵或注意力分数矩阵对角线上面的元素,并计算因果注意力权重矩阵。注意力权重 α i 1 , α i 2 , ⋯ , α i i \alpha_{i1}, \alpha_{i2}, \cdots, \alpha_{ii} αi1,αi2,⋯,αii之和等于1, α i , i + 1 , α i , i + 2 , ⋯ , α i n \alpha_{i,i+1}, \alpha_{i,i+2}, \cdots, \alpha_{in} αi,i+1,αi,i+2,⋯,αin全都等于0,context向量 z i z_i zi只与 x 1 , x 2 , ⋯ , x i x_1, x_2, \cdots, x_i x1,x2,⋯,xi相关,模型只根据历史信息预测下一个token。
可以使用部分元素值为0且其余元素值为1的mask矩阵与因果注意力权重矩阵逐元素相乘,随机屏蔽因果注意力权重矩阵中的部分元素,对注意力机制做Dropout,防止模型过拟合。
万事俱备,是时候掀开多头注意力机制的面纱了!