[DL]Attention机制解读

文章介绍了Attention机制在解决递归神经网络如RNN在长序列处理中的局限性,特别是在机器翻译任务中的应用。Attention机制的核心思想是通过注意力分数计算query与key的相似度,然后通过softmax归一化得到注意力权重,再与value进行加权求和得到最终表示。文章详细讨论了加性注意力和缩放点积注意力两种实现方式,并引出了Self-Attention和Multi-headSelf-Attention的概念,后者通过多个注意力头来捕捉序列内的不同依赖关系,提高了模型的表达能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

I-背景简介

在递归神经网络(RNN、LSTM等)已经被牢固确立为序列建模和转换问题(如机器翻译)的最先进方法背景下,大量的工作继续推动着递归语言模型和Encoder-Decoder架构的发展。

递归模型通常沿着输入和输出序列的符号位置进行计算,将位置与计算时间的步骤对齐,它们产生一连串的隐藏状态ht,位置t的ht取决于前一个隐藏状态ht-1和t的输入。这种固有的顺序性质排除了训练实例内部的并行化,这在较长的序列上是至关重要的,因为内存限制了跨实例的批处理。

Encoder-Decoder结构中,编码器负责将输入序列转换为一种中间表示,解码器则将该中间表示转换为输出序列。

问题引入

以机器翻译任务为例,在基于Encoder-Decoder结构的循环神经网络中,先通过一个Encoder循环神经网络读入所有的待翻译句子中的单词,得到一个包含原文所有信息的中间隐藏层,接着把中间隐藏层状态输入Decoder网络,一个词一个词的输出翻译句子。这样子无论输入中的关键词语有着怎样的先后次序,由于都被打包到中间层一起输入后方网络,Encoder-Decoder网络可以很好地处理这些词的输出位置和形式。
Alt
问题在于,中间状态由于来自于输入网络最后的隐藏层,一般来说它是一个大小固定的向量。既然是大小固定的向量,那么它能储存的信息就是有限的,当句子长度不断变长,由于后方的Decoder网络的所有信息都来自中间状态,中间状态需要表达的信息就越来越多。如果句子的信息是在太多,网络就难以把握学习。也就是说Encoder-Decoder网络存在容量上限,为了解决这个问题,提出了Attention机制。

在各种任务中,Attention机制允许对依赖环境进行建模,不考虑它们在输入输出序列中的距离。

Attention机制的心理学解释

  • 动物需要在复杂环境下有效关注值得注意的点
  • 心理学框架:人类根据刻意线索(主观能动)和无意线索(下意识)选择注意点

II-Attention机制原理

核心思想

  • 卷积、全连接、池化层都只考虑无意线索

  • Attention机制则考虑刻意线索

    • 刻意线索被称为查询(query)
    • 每个输入是一个值(value)和无意线索(key)的对
    • 通过注意力池化层来有偏向性的选择某些输入

核心逻辑就是「从关注全部到关注重点」。将有限的注意力集中在重点信息上,从而节省资源,快速获得最有效的信息。

发展历程

非参注意力池化层(不学习参数)

  • 给定数据(xi,yi),i=1,…,n。(xi,yi)即为key-value对

  • 平均池化是最简单的方案
    f ( x ) = 1 n ∑ i y i f(x)=\frac{1}{n}\sum_{i}y_i f(x)=n1iyi

  • 更好的方案是60年代提出来的Nadaraya-Waston核回归x即为query,xi,j为key,yi为value,K是衡量x与xi之间距离的函数

    f ( x ) = ∑ i = 1 n K ( x − x i ) ∑ j = 1 n K ( x − x j ) y i f(x)=\sum_{i=1}^n\frac{K(x-x_i)}{\sum_{j=1}^nK(x-x_j)}y_i f(x)=i=1nj=1nK(xxj)K(xxi)yi

    • 使用高斯核
      K ( u ) = 1 2 π e x p ( − u 2 2 ) K(u)=\frac{1}{\sqrt{2π}}exp(-\frac{u^2}{2}) K(u)=2π 1exp(2u2)
      那么
      f ( x ) = ∑ i = 1 n e x p ( − 1 2 ( x − x i ) 2 ) ∑ j = 1 n e x p ( − 1 2 ( x − x j ) 2 ) y i = ∑ i = 1 n s o f t m a x ( − 1 2 ( x − x i ) 2 ) y i f(x)=\sum_{i=1}^n\frac{exp(-\frac{1}{2}(x-x_i)^2)}{\sum_{j=1}^nexp(-\frac{1}{2}(x-x_j)^2)}y_i=\sum_{i=1}^nsoftmax(-\frac{1}{2}(x-x_i)^2)y_i f(x)=i=1nj=1nexp(21(xxj)2)exp(21(xxi)2)yi=i=1nsoftmax(21(xxi)2)yi

参数化的注意力机制

  • 在之前基础上引入可以学习的参数w
    f ( x ) = ∑ i = 1 n s o f t m a x ( − 1 2 ( ( x − x i ) w ) 2 ) y i f(x)=\sum_{i=1}^nsoftmax(-\frac{1}{2}((x-x_i)w)^2)y_i f(x)=i=1nsoftmax(21((xxi)w)2)yi

注意力分数

注意力分数:query和key的相似度

注意力权重:注意力分数的softmax结果

f ( x ) = ∑ i = 1 n α ( x , x i ) y i = ∑ i = 1 n s o f t m a x ( − 1 2 ( x − x i ) 2 ) y i f(x)=\sum_{i=1}^nα(x,x_i)y_i=\sum_{i=1}^nsoftmax(-\frac{1}{2}(x-x_i)^2)y_i f(x)=i=1nα(x,xi)yi=i=1nsoftmax(21(xxi)2)yi

上述公式中,α(x,xi)为注意力权重,softmax函数中部分为注意力分数。(权重就是softmax后属于[0,1]之间的值,分数没有范围没有进行normalize,如下图所示)
Alt
扩展到高维度

假设query q∈Rq,m对key-value(k1v1),…,ki∈Rkvi∈Rv

注意力池化层:
f ( q , ( k 1 , v 1 ) , . . . ( k m , v m ) ) = ∑ i = 1 m α ( q , k i ) v i ∈ R v f(\mathbf{q},(\mathbf{k_1},\mathbf{v_1}),...(\mathbf{k_m},\mathbf{v_m}))=\sum_{i=1}^mα(\mathbf{q},\mathbf{k_i})\mathbf{v_i}\in{R^v} f(q,(k1,v1),...(km,vm))=i=1mα(q,ki)viRv

α ( q , k i ) = s o f t m a x ( a ( q , k i ) ) = e x p ( a ( q , k i ) ) ∑ j = 1 m e x p ( a ( q , k j ) ) ∈ R α(\mathbf{q},\mathbf{k_i})=softmax(a(\mathbf{q},\mathbf{k_i}))=\frac{exp(a(\mathbf{q},\mathbf{k_i}))}{\sum_{j=1}^mexp(a(\mathbf{q},\mathbf{k_j}))}\in{R} α(q,ki)=softmax(a(q,ki))=j=1mexp(a(q,kj))exp(a(q,ki))R

其中,a(q,ki)为注意力分数。

加性注意力(Additive Attention)

可学参数:Wk∈Rh×kWq∈Rh×qv∈Rh
a ( q , k ) = v T t a n h ( W k k + W q q ) a(\mathbf{q},\mathbf{k})=\mathbf{v}^Ttanh(\mathbf{W}_k\mathbf{k}+\mathbf{W}_q\mathbf{q}) a(q,k)=vTtanh(Wkk+Wqq)
等价于将query和key合并起来后放入一个隐藏大小为h输出为1的单隐藏层MLP。

tanh是一种双曲正切函数,是一种常用的激活函数,取值范围为[-1, 1],通常用于神经网络中。

tanh函数的数学表达式为:tanh(x) = (e^x - e^-x) / (e^x + e^-x)

缩放点积注意力(Scaled Dot-Product Attention)

如果query和key向量长度相同,q,ki∈Rd,那么
a ( q , k i ) = < q , k i > d a(\mathbf{q},\mathbf{k}_i)=\frac{<\mathbf{q},\mathbf{k}_i>}{\sqrt{d}} a(q,ki)=d <q,ki>
为了梯度稳定,向量维度越长会导致内积值越大,显然并不能将因向量维度长短导致的内积值作为判断相关性的依据,因此需要除以维度的影响。

缩放点积注意力相比于加性注意力在计算效率上更高,并且在实践中取得了较好的效果。通过缩放因子的引入,缩放点积注意力可以更好地控制注意力权重的范围,并提供更稳定的注意力计算。因此,在大多数情况下,缩放点积注意力是更常用和推荐的注意力机制。

注意力池化层

将注意力分数通过softmax进行归一化,与value进行点乘,将所有结果相加,得到最后的向量。(类似于加权求和)

向量化版本,Q∈Rn×dK∈Rm×dV∈Rm×v

注意力分数
a ( Q , K ) = < K T , Q , > d ∈ R n × m a(\mathbf{Q},\mathbf{K})=\frac{<\mathbf{K}^T,\mathbf{Q},>}{\sqrt{d}}\in{R}^{n×m} a(Q,K)=d <KT,Q,>Rn×m
注意力池化
f = V T s o f t m a x ( a ( Q , K ) ) ∈ R n × v f=\mathbf{V^T}softmax(a(\mathbf{Q},\mathbf{K}))\in{R}^{n×v} f=VTsoftmax(a(Q,K))Rn×v

III-Self-Attention

给定序列x1,…,xn,xi∈Rd,自注意力池化层,将xi当做query=key=value对序列进行抽取特征得到y1,…,yn,其中
y i = f ( x i , ( x 1 , x 1 ) , … … , ( x n , x n ) ) ∈ R d y_i=f(x_i,(x_1,x_1),……,(x_n,x_n))\in{R^d} yi=f(xi,(x1,x1),……,(xn,xn))Rd
自注意力没有记录位置信息,需要将位置信息注入到输入里,假设输入长度为n的序列X∈Rn×d,那么使用位置编码矩阵P∈Rn×d,使X+P作为自编码输入,位置矩阵P计算如下图所示。位置编码在自注意力机制中的作用是为序列中的每个位置提供位置信息。通过将位置编码添加到输入序列的嵌入表示中,模型可以学习到不同位置之间的相关性和依赖关系。
Alt

Self-Attention与Attention的区别

  • Attention通常应用于序列到序列的任务,例如机器翻译。在这种情况下,Attention可以输入序列的不同部分自适应的加权计算输入序列的信息,即将输入序列的每个位置与当前输出序列的每个位置进行比较,计算它们之间的相似度,并根据相似度计算的权重将所有输入序列信息进行加权相加,得到当前输出序列的位置的表示。
  • Self-Attention通常应用于序列内部信息的提取。在这种情况下,不需要额外的输入序列,而是直接对输入序列的每个位置进行比较加权,即将该输入序列每个位置转换为query=key=value三个向量,得到每个位置的表示。因此Self-Attention有助于模型学习序列内部的依赖关系,并实现并行计算,提高长序列处理的效率。

Self-Attention和Attention机制的主要区别就在于所计算的权重不同,Self-Attention机制只考虑输入序列中的内部关系,而Attention机制还需要考虑与另一个序列之间的关系。

IV-Multi-head Self-attention(Self-attention的变形 )

自注意力机制的缺陷就是:模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身的位置, 因此提出了通过多头注意力机制来解决这一问题。

多头注意力进一步扩展了自注意力机制,它在计算权重时使用了多个注意力头(随机初始化),每个头都学习一组不同的权重,每个注意力头都可以学习不同的相关性模式。在计算聚合后的表示时,多头注意力会将每个头计算得到的表示进行拼接,然后通过一个线性变换来得到最终的表示。
Alt
需要注意的一点是,每个头只能和对应的头进行运算。比如,qi,2只能与ki,2,kj,2进行计算,而不能与ki,1,kj,1进行计算。得到bi,1,bi,2后,前馈层不需要两个矩阵,因此还需要通过附加的权重矩阵与拼接后的b矩阵进行相乘,得到最终的bi矩阵。
Alt

V-Attention代码

加性注意力(Additive Attention)

  1. 导入所需的库

    import math
    import torch
    from torch import nn
    from d2l import torch as d2l
    
  2. mask softmax操作

    # 通过在最后一个轴上遮蔽元素来执行softmax操作
    def masked_softmax(X, valid_lens):  # valid_lens:有效进行softmax的位数长度
        # X.shape=[batchsize, sequence_length, hidden_size],valid_lens:1或2维张量
        if valid_lens is None:
            return nn.functional.softmax(X, dim=-1)
        else:
            shape = X.shape
            if valid_lens.dim() == 1:
                valid_lens = torch.repeat_interleave(valid_lens, shape[1])  # 沿着张量的指定维度重复元素
            else:
                valid_lens = valid_lens.reshape(-1)  # 按行优先展开向量
            X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)  # 给无效位置赋极大负值,使其softmax值为0
            return nn.functional.softmax(X.reshape(shape), dim=-1)
    

    输入X与valid_lens的对应情况的三个例子:
    Alt
    Alt
    Alt

  3. 加性注意力

    class AdditiveAttention(nn.Module):
        """加性注意力"""
        def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
            super(AdditiveAttention, self).__init__(**kwargs)
            self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
            self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
            self.w_v = nn.Linear(num_hiddens, 1, bias=False)
            self.dropout = nn.Dropout(dropout)
    
        def forward(self, queries, keys, values, valid_lens):  # valid_lens:有效个key-value对
            queries, keys = self.W_q(queries), self.W_k(keys)  # querie.shape=[batchsize, num of queries, num_hiddens]
                                                                # keys.shape=[batchsize, num of keys, num_hiddens]
            features = queries.unsqueeze(2) + keys.unsqueeze(1)  # queries.unsqueeze(2).shape=[batchsize, num of queries, 1, num_hiddens]
                                                                # keys.unsqueeze(1).shape=[batchsize, 1, num of keys, num_hiddens]
                                                                # features.shape=[batchsize, num of queries, num of keys, num_hiddens]
            features = torch.tanh(features)  
            scores = self.w_v(features).squeeze(-1)  # self.w_v(features).shape=[batchsize, num of queries, num of keys, 1]
                                                    # scores.shape=[batchsize, num of queries, num of keys]
            self.attention_weights = masked_softmax(scores, valid_lens)
            return torch.bmm(self.dropout(self.attention_weights), values)  # torch.bmm按批次做矩阵乘法
                                                                            # 结果.shape=[batchsize, num of querier, num of values]
    
  4. 测试样例

    queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))  # queries.shape=[2, 1, 20], keys.shape=[2, 10, 2]
    values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)  # values.shape=[2, 10, 4]
    valid_lens = torch.tensor([2, 6])
    
    attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
    attention.eval()
    attention(queries, keys, values, valid_lens)  # .shape=[2, 1, 4]
    

缩放点积注意力(Scaled Dot-Product Attention)

  1. 导入所需的库

  2. mask softmax操作

  3. 缩放点积注意力

    class DotProductAttention(nn.Module):
        """缩放点积注意力"""
        def __init__(self, dropout, **kwargs):
            super(DotProductAttention, self).__init__(**kwargs)
            self.dropout = nn.Dropout(dropout)
    
        def forward(self, queries, keys, values, valid_lens=None):
            d = queries.shape[-1]
            scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
            self.attention_weights = masked_softmax(scores, valid_lens)
            return torch.bmm(self.dropout(self.attention_weights), values)
    
  4. 测试样例

    queries = torch.normal(0, 1, (2, 1, 2))
    attention = DotProductAttention(dropout=0.5)
    attention.eval()
    attention(queries, keys, values, valid_lens)
    

Attention Is All You Need 论文

参考笔记1

参考笔记2

参考讲解视频-李沐

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值