注意力机制 - 注意力汇聚:Nadaraya-Watson核回归

注意力汇聚:Nadaraya-Watson核回归

框架下的注意力机制的主要成分:查询(自主提示)和键(非自主提示)之间交互形成了注意力汇聚,注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。在本节中,我们将介绍注意力汇聚的更多细节,以便从宏观上了解注意力机制在实践中的运作方式。1964年提出的Nadaraya-Watson核回归模型是⼀个简单但完整的例⼦,可以⽤于演⽰具有注意⼒机制的机器学习

import torch
from torch import nn
from d2l import torch as d2l

1 - 生成数据集

n_train = 50 # 训练样本数
x_train,_ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本
def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0,0.5,(n_train,)) # 训练样本的输出
x_test = torch.arange(0,5,0.1) # 测试样本
y_truth = f(x_test) # 测试样本的真实输出
n_test = len(x_test) # 测试样本数
n_test
50

下面的函数将绘制所有的训练样本(样本由圆圈表示),不带噪声项的真实数据生成函数f(标记为“Truth”),以及学习得到的预测函数(标记为“Pred”)

def plot_kernel_reg(y_hat):
    d2l.plot(x_test,[y_truth,y_hat],'x','y',legend=['Truth','Pred'],xlim=[0,5],ylim=[-1,5])
    d2l.plt.plot(x_train,y_train,'o',alpha=0.5);

2 - 平均汇聚

y_hat = torch.repeat_interleave(y_train.mean(),n_test)
plot_kernel_reg(y_hat)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j6TRUXSQ-1662988499736)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107812.svg)]

3 - 非参数注意力汇聚


# X_repeat的形状:(n_test,n_train)
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1,n_train))

# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每⼀⾏都包含着要在给定的每个查询的值(y_train)之间分配的注意⼒权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2,dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights,y_train)
plot_kernel_reg(y_hat)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bHHxRgOe-1662988499737)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107813.svg)]

现在,我们来观察注意力的权重,这里测试数据的输入相当于查询,而训练数据的输入相当于键。因为两个输入都是经过排序的,因此由观察可知,“查询-键”对越接近,注意力汇聚的注意力权重就越高

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                 xlabel='Sorted training inputs',
                 ylabel='Sorted testing inputs')


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-d4R52D1O-1662988499737)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107814.svg)]

4 - 带参数注意力汇聚

批量矩阵乘法

X = torch.ones((2,1,4))
Y = torch.ones((2,4,6))

torch.bmm(X,Y).shape
torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值

weights = torch.ones((2,10)) * 0.1
values = torch.arange(20.0).reshape((2,10))
torch.bmm(weights.unsqueeze(1),values.unsqueeze(-1))
tensor([[[ 4.5000]],

        [[14.5000]]])

定义模型

基于带参数的注意力汇聚,使用小批量矩阵乘法,定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,),requires_grad = True))
        
    def forward(self,queries,keys,values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1,keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 /2 ,dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                        values.unsqueeze(-1)).reshape(-1)

训练

接下来,将训练数据集变换为键和值用于训练注意力模型。在带参数的注意力汇聚模型中,任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,从而得到其对应的预测输出

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train,1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train,1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(),lr=0.5)
animator = d2l.Animator(xlabel='epoch',ylabel='loss',xlim=[1,5])

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train,keys,values),y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5nsQsico-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107815.svg)]

如下所示,训练完带参数的注意力汇聚模型后,我们发现:在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的平滑

# keys的形状:(n_test,n_train),每⼀⾏包含着相同的训练输⼊(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-V9IQPxat-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107816.svg)]

为什么新的模型更不平滑了呢?我们看一下输出结果的绘制图:与非参数的注意力汇聚模型相比,带参数的模型加入可学习的参数后,曲线在注意力权重较大的区域变得更不平滑

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                xlabel='Sorted training inputs',
                ylabel='Sorted testing inputs')


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rRT7kL7O-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107817.svg)]

5 - 小结

  • Nadaraya-Watson核回归时具有注意力机制的机器学习范例
  • Nadaraya-Watson核回归的注意⼒汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于你将值所对应的键核查询作为输入的函数
  • 注意力汇聚可以分为非参数型核带参数型
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值