自注意力
在有了注意力机制后,我们将词元序列输入注意力池化中,以便同一组词元同时充当查询、键和值。具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,因此被称为自注意力。
给定一个由词元组成的输入序列
x
1
,
⋯
,
x
n
x_1,\cdots,x_n
x1,⋯,xn,其中任意
x
i
∈
R
d
x_i\in R^d
xi∈Rd,该序列的自注意力输出为一个长度相同的序列
y
1
,
⋯
,
y
n
y_1,\cdots,y_n
y1,⋯,yn,其中:
y
i
=
f
(
x
i
,
(
x
1
,
x
1
)
,
⋯
,
(
x
n
,
x
n
)
)
∈
R
d
y_i = f(x_i,(x_1,x_1),\cdots,(x_n,x_n))\in R^d
yi=f(xi,(x1,x1),⋯,(xn,xn))∈Rd
函数
f
f
f是注意力函数吗,(query,(key,value),…)
import math
import torch
from torch import nn
from d2l import torch as d2l
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape # 自注意力
1.位置编码
与CNN、RNN不同,自注意力没有记录位置的信息,因为并行计算而放弃了顺序操作。为了使用序列的顺序信息,通过在输入表示中添加位置编码来注入绝对的或相对的位置信息。位置编码可以通过学习得到也可以直接固定得到。
假设长度为n的序列是
x
∈
R
n
×
d
x\in R^{n\times d}
x∈Rn×d,那么使用位置编码矩阵
P
∈
R
n
×
d
P\in R^{n\times d}
P∈Rn×d来输出
X
+
P
X+P
X+P作为自编码输入,
P
P
P的计算:
p
i
,
2
j
=
s
i
n
(
i
1000
0
2
j
d
)
,
p
i
,
2
j
+
1
=
c
o
s
(
i
1000
0
2
j
d
)
p_{i,2j} = sin(\frac {i}{10000^{\frac {2j}d}}),p_{i,2j+1}=cos(\frac{i}{10000^{\frac{2j}d}})
pi,2j=sin(10000d2ji),pi,2j+1=cos(10000d2ji)
#@save
class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
d2l.plt.show()
1.1 绝对位置信息
在二进制表示中,较高比特位的交替频率低于较低比特位, 与下面的热图所示相似,只是位置编码通过使用三角函数在编码维度上降低频率。 由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间。
for i in range(8):
print(f'{i}的二进制是:{i:>03b}')
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
d2l.plt.show()
1.2 相对位置信息
位置 i + δ i+\delta i+δ处的位置编码可以线性投影位置 i i i处的位置编码来表示,记 w j = 1 1000 0 2 j d w_j =\frac {1}{10000^{\frac{2j}d}} wj=10000d2j1,则