机器学习笔记:注意力机制中多头注意力的实现

目录

介绍

模型

代码实现 

引入库

单个注意力头

多个注意力头的实现

测试

思考


介绍

在注意力机制中,单个注意力学到的东西有限,可以通过对不同的注意力进行组合,学到不同的知识,以达到想要的目的。因此采用”多头注意力“的方法进行实现,即有多个注意力”头“,对其进行连结得到输出。

模型

首先,对于我们输入的查询,以及每一个键值对,都有需要学习的一系列权重参数W,另外,注意力汇聚函数f也需要学习得到。 多头注意力的输出需要经过另一个线性转换, 它对应着h个头连结后的结果,因此这里也有一个参数需要进行学习。基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

代码实现 

引入库

首先引入深度学习相关的库。

import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

单个注意力头

这里,我们使用缩放点积注意力,先来对每一个注意力头进行实现。这里首先说明一点,即可以设定p_q=p_k=p_v=\frac{p_o}{h}。如果将查询、键和值的线性变换的输出数量设置为p_qh=p_kh=p_vh=p_o,则可以并行计算h个头。

[注,原文如此,但是我其实完全没有明白它这里在说什么,我不知道为什么这样设置。]

详解

这里定义一个多头注意力类,定义其头的数量,并定义隐藏层数量,以实现缩放点击注意力。在前向计算时,注意queries.shape=(batchSize,queryNum,numHiddens),key.shape=values.shape=(batchSize,k-vNum,numHiddens)。经过变换后,输出的queries,keys,values 的形状:  (batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)

class MultiHeadAttention(nn.Block):
    """多头注意力"""
    def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
                 **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

    def forward(self, queries, keys, values, valid_lens):
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            valid_lens = valid_lens.repeat(self.num_heads, axis=0)
            output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

多个注意力头的实现

为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens) 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)

def transpose_qkv(X, num_heads):
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.transpose(0, 2, 1, 3)

    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

测试

下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

输出结果和我们的想法是一样的:

(2, 4, 100) 

思考

  1. 假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?

  • 9
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值