精(李沐)多头注意力,代码理解

多头总框架 

多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

而多头self-attention模块,则是将Q,K,V通过参数矩阵映射后(即Q,K,V分别接一个全连接层),通过张量操作(X.reshape())将张量变换为可以实现多个头并行计算的样子,然后再做self-attention,将这个过程重复h(原论文中h=8)次,最后再将所有的结果拼接起来,再送入一个全连接层即可,图示如上:

代码解析:

 1、经过参数矩阵映射(即Q,K,V分别接一个全连接层):

self.W_q = nn.Linear(query_size,  num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size,    num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size,  num_hiddens, bias=bias)

2、通过张量操作将张量变换为可以实现多个头并行计算的样子(即下面的transpose_qkv函数) :

queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys =    transpose_qkv(self.W_k(keys),    self.num_heads)
values =  transpose_qkv(self.W_v(values),  self.num_heads)

3、然后再做self-attention:

output = self.attention(queries, keys, values, valid_lens)

4、将这个过程重复h(原论文中h=8)次(即上述的valid_lens实现重复h次):

# valid_lens 的形状: (batch_size,)或(batch_size,查询的个数)
if valid_lens is not None:
    # 在轴0,将第一项(标量或者矢量)复制num_heads次,然后如此复制第二项,然后诸如此类。
    valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

5、最后再将所有的结果拼接起来,再送入一个全连接层即可:

output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)

其中W_o是: 

self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

总体代码如下:

import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        # 此处使用缩放点积注意力作为每一个注意力头
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size,  num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size,    num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size,  num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状: (batch_size,查询或者“键-值”对的个数,num_hiddens)
        # 经过变换后,输出的queries,keys,values 的形状: (batch_size * num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys =    transpose_qkv(self.W_k(keys),    self.num_heads)
        values =  transpose_qkv(self.W_v(values),  self.num_heads)

        # valid_lens 的形状: (batch_size,)或(batch_size,查询的个数)
        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

上述代码31行有“ output = self.attention(queries, keys, values, valid_lens) ” ,其实这一步便是完成以下公式的操作,所谓多头,只是多个attention同时计算罢了


通过张量操作实现多个头并行计算

基于适当的张量操作,可以实现多头注意力的并行计算

为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数(transpose_qkv与transpose_output)。

具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#######################################################################
#### 为了能够使多个头并行计算,上面的MultiHeadAttention类将使用下面定义的两个转置函数。
#### 具体来说,transpose_output函数反转了transpose_qkv函数的操作。
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # 最终输出的形状:(batch_size * num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

测试 

下面我们使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。


#### 下面我们使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_size,num_queries,num_hiddens)
num_hiddens, num_heads = 100, 5
# key_size, query_size, value_size,与num_hiddens相同
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
print(attention.eval())


batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens)) # (2,4,100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) # (2,6,100)
print(attention(X, Y, Y, valid_lens).shape)

10.5. 多头注意力 — 动手学深度学习 2.0.0-beta0 documentation

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Pengsen Ma

太谢谢了

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值