5.4.自注意力

自注意力

​ 在有了注意力机制后,我们将词元序列输入注意力池化中,以便同一组词元同时充当查询、键和值。具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,因此被称为自注意力。

​ 给定一个由词元组成的输入序列 x 1 , ⋯   , x n x_1,\cdots,x_n x1,,xn,其中任意 x i ∈ R d x_i\in R^d xiRd,该序列的自注意力输出为一个长度相同的序列 y 1 , ⋯   , y n y_1,\cdots,y_n y1,,yn,其中:
y i = f ( x i , ( x 1 , x 1 ) , ⋯   , ( x n , x n ) ) ∈ R d y_i = f(x_i,(x_1,x_1),\cdots,(x_n,x_n))\in R^d yi=f(xi,(x1,x1),,(xn,xn))Rd
​ 函数 f f f是注意力函数吗,(query,(key,value),…)

在这里插入图片描述

在这里插入图片描述

import math
import torch
from torch import nn
from d2l import torch as d2l


num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()


batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape # 自注意力

1.位置编码

​ 与CNN、RNN不同,自注意力没有记录位置的信息,因为并行计算而放弃了顺序操作。为了使用序列的顺序信息,通过在输入表示中添加位置编码来注入绝对的或相对的位置信息。位置编码可以通过学习得到也可以直接固定得到。

​ 假设长度为n的序列是 x ∈ R n × d x\in R^{n\times d} xRn×d,那么使用位置编码矩阵 P ∈ R n × d P\in R^{n\times d} PRn×d来输出 X + P X+P X+P作为自编码输入, P P P的计算:
p i , 2 j = s i n ( i 1000 0 2 j d ) , p i , 2 j + 1 = c o s ( i 1000 0 2 j d ) p_{i,2j} = sin(\frac {i}{10000^{\frac {2j}d}}),p_{i,2j+1}=cos(\frac{i}{10000^{\frac{2j}d}}) pi,2j=sin(10000d2ji),pi,2j+1=cos(10000d2ji)
在这里插入图片描述

#@save
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])


d2l.plt.show()

1.1 绝对位置信息

​ 在二进制表示中,较高比特位的交替频率低于较低比特位, 与下面的热图所示相似,只是位置编码通过使用三角函数在编码维度上降低频率。 由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间。

for i in range(8):
    print(f'{i}的二进制是:{i:>03b}')


P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
d2l.plt.show()

在这里插入图片描述

1.2 相对位置信息

​ 位置 i + δ i+\delta i+δ处的位置编码可以线性投影位置 i i i处的位置编码来表示,记 w j = 1 1000 0 2 j d w_j =\frac {1}{10000^{\frac{2j}d}} wj=10000d2j1,则

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值