自注意力
- 给定序列
- 自注意力池化层将当做key,value,query来对序列抽取特征得到
是自注意力池化函数
位置编码
- 没有记录位置信息
- 位置编码将位置信息注入到输入里
- 假设长度为n的序列是,那么使用位置编码矩阵来输出作为自编码输入
- P的元素如下计算:
位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。 第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替。
#@save
class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
在位置嵌入矩阵P中, 行代表词元在序列中的位置,列代表位置编码的不同维度。
绝对位置信息
每个数字、每两个数字和每四个数字上的比特值 在第一个最低位、第二个最低位和第三个最低位上分别交替。
在二进制表示中,较高比特位的交替频率低于较低比特位, 与下面的热图所示相似,只是位置编码通过使用三角函数在编码维度上降低频率。 由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间。
相对位置信息
对于任何确定的位置偏移δ,位置i+δ处 的位置编码可以线性投影位置i处的位置编码来表示。
这种投影的数学解释是,令ωj=1/100002j/d, 对于任何确定的位置偏移δ, 任何一对 (pi,2j,pi,2j+1)都可以线性投影到 (pi+δ,2j,pi+δ,2j+1):
小结
- 自注意力池化层将xi当做key,value,query来对序列抽取特征
- 完全并行、最长序列为1、但对长序列计算复杂度高
- 位置编码在输入中加入位置信息,使得自注意力能够记忆当前位置信息