李宏毅机器学习_各种各样神奇的自注意力机制(Self-attention)

目录

摘要

ABSTRACT

一、How to make self-attention efficient?

二、Local Attention

三、Stride Attention

四、Global Attention

五、三种方法的比较

一、(Local Attention)

二、(Stride Attention)

三、(Global Attention)

六、Can we only focus on Critical Parts?

七、Clustering

一、什么是Clustering

八、Learnable Patterns

 九、数学推导

十、代码示例

一、数据准备

二、 初始化权重

三、导出key, query and value的表示

四、计算输入的attention scores

五、计算softmax 

六、将attention scores乘以value

七、对加权后的value求和以得到输出

总结


摘要

本周学习了三种不同的自注意力机制,Local Attention 专注于输入序列的一个局部窗口,使得模型在处理一个特定元素时,只关注其附近的其他元素。相对地,Global Attention 考虑输入序列的全部元素,为每个元素分配不同的权重。Stride Attention 则是介于两者之间,它按照一定的步长(stride)选择性地关注输入序列的元素,从而在保证模型性能的同时,降低了计算复杂性。

ABSTRACT

This week, three distinct self-attention mechanisms were studied. Local Attention focuses on a local window of the input sequence, enabling the model to concentrate on nearby elements when handling a specific item. In contrast, Global Attention considers the entire input sequence, assigning differing weights to each element. Stride Attention lies between the two, selectively attending to elements of the input sequence at a certain stride. This approach ensures model performance while reducing computational complexity.

一、How to make self-attention efficient?

假设输入序列(query)长度是 N,为了捕捉每个 value 或者 token 之间的关系,需要对应产生 N 个 key 与之对应,并将  query 与 key 之间做 dot-product,就可以产生一个 Attention Matrix(注意力矩阵),维度 N*N。这种方式最大的问题就是当序列长度太长的时候,对应的 Attention Matrix 维度太大,会给计算带来麻烦,计算时间会很长。

24f2d2a2589bbd8d723286c6b7d18d9e.png

对于 transformer 来说,self-attention 只是大的网络架构中的一个 module。由上述分析我们知道,对于 self-attention 的运算量是跟 N 的平方成正比的。当 N 很小的时候,单纯增加 self-attention 的运算效率可能并不会对整个网络的计算效率有太大的影响,比如说,如果在Feed Forward时需要的计算量大,如果去增加self-attention的运算效率,并不能使其速率变快。因此,提高 self-attention 的计算效率从而大幅度提高整个网络的效率的前提是 N 特别大的时候,比如做图像识别(影像辨识、image processing)。

05b930cabfc1131671a3d8677b43cb84.png

二、Local Attention

如何加快 self-attention 的速度呢?根据上述分析可以知道,影响 self-attention 效率最大的一个问题就是 Attention Matrix 的计算。如果我们可以根据一些人类的知识或经验,选择性的计算 Attention Matrix 中的某些数值或者某些数值不需要计算就可以知道数值,理论上可以减小计算量,提高计算效率。举个例子,比如我们在做文本翻译的时候,有时候在翻译当前的 token 时不需要给出整个 sequence,其实只需要知道这个 token 两边的邻居,就可以翻译的很准,也就是做局部的 attention(local attention)。这样可以大大提升运算效率,但是缺点就是只关注周围局部的值,这样做法其实跟 CNN 就没有太大的区别了,毕竟CNN也是关注一小部分的数据。

bc1916669a19f323767ebf451e00c21c.png

三、Stride Attention

当然,我们可能会认为只是考虑邻居的话,考虑的范围太窄,所以我们想到考虑稍微大一点的范围,比如说考虑隔两个距离的这一批数据,就是在翻译当前 token 的时候,给它空一定间隔(stride)的左右邻居,从而捕获当前与过去和未来的关系。当然stride的数值可以自己确定。

c2b9e9442474322f2f84d8578d57fff9.png

四、Global Attention

还有一种 global attention 的方式,就是选择 sequence 中的某些 token 作为 special token(比如标点符号),或者在原始的 sequence 中增加 special token。让 special token 与序列产生全局的关系,但是其他不是 special token 的 token 之间没有 attention。以在原始 sequence 前面增加两个 special token 为例:

c0a550d66416b9429d82fe6220270901.png

五、三种方法的比较

一、(Local Attention)

局部注意力是一种只关注输入序列中的一部分(即局部范围)的注意力机制。在这种情况下,模型不是对所有可能的位置进行注意,而是只对输入序列中的一个小的、连续的部分进行注意。这种方法的优点是计算效率更高,因为它减少了需要处理的位置的数量。然而,它的缺点是它可能忽视了距离当前位置较远但仍然相关的信息。

二、(Stride Attention)

步幅注意力是一种介于全局注意力和局部注意力之间的注意力机制。在这种机制中,模型会跳过一些位置,只关注那些"步幅"位置。这种方法的优点是它提供了一种平衡的方法,可以同时关注多个位置,但又不需要处理所有的位置。然而,和局部注意力一样,它可能会忽视一些重要的信息。

三、(Global Attention)

全局注意力是一种关注输入序列中所有位置的注意力机制。在这种情况下,模型会对所有可能的位置进行计算,以决定应该赋予每个位置多大的注意力。这种方法的优点是它能够捕获序列中所有位置的信息,这对于理解复杂的模式和关系可能是必要的。然而,它的缺点是计算效率低,因为需要处理大量的位置。

到底哪种 attention 最好呢?小孩子才做选择...对于一个网络,有的 head 可以做 local attention,有的 head 可以做 global attention... 这样就不需要做选择了。还是要具体问题具体分析,看下面几个例子:

Longformer 就是组合了上面的三种 attention

Big Bird 就是在 Longformer 基础上随机选择 attention 赋值,进一步提高计算效率

e1f597c405e19b29da1098df9abee16c.png

六、Can we only focus on Critical Parts?

上面集中方法都是人为设定的哪些地方需要算 attention,哪些地方不需要算 attention,但是这样算是最好的方法吗?并不一定。对于 Attention Matrix 来说,如果某些位置值非常小,我们可以直接把这些位置置 0,这样对实际预测的结果也不会有太大的影响。也就是说我们只需要找出 Attention Matrix 中 attention 的值相对较大的值。但是如何找出哪些位置的值非常小/非常大呢? 

3fb977c69f7ec629755daa89c108f74a.png

七、Clustering

一、什么是Clustering

聚类(Clustering)是无监督学习的一种方法,它试图找出数据中的潜在分组。这些分组是基于数据对象之间的相似性或距离度量来确定的。在聚类中,我们不需要预先知道每个观察对象的类别或标签,这是它区别于监督学习的主要特征。

利用Clustering(聚类)的方案,即先对 query 和 key 进行聚类。属于同一类的 query 和 key 来计算 attention,不属于同一类的就不参与计算,这样就可以加快 Attention Matrix 的计算。比如下面这个例子中,分为 4 类:1(红框)、2(紫框)、3(绿框)、4(黄框)。

7da02dc2890f5d528a1475d785cea892.png

f1f71eebe190332dfbbd232fed2a76a9.png

八、Learnable Patterns

有没有一种将要不要算 attention 的事情用 learn 的方式学习出来呢?有可能的。我们再训练一个网络,输入是 input sequence,输出是相同长度的 weight sequence。将所有 weight sequence 拼接起来,再经过转换,就可以得到一个哪些地方需要算 attention,哪些地方不需要算 attention 的矩阵。有一个细节是:某些不同的 sequence 可能经过 NN 输出同一个 weight sequence,这样可以大大减小计算量。

5525b7c63dab1f651f0ade3b51fb60f4.png

上述我们所讲的都是 N*N 的 Matrix,但是实际来说,这样的 Matrix 通常来说并不是满秩的,也就是说我们可以对原始 N*N 的矩阵降维,将重复的 column 去掉,得到一个比较小的 Matrix。

31a9bd37e1c31c61f9ff5b049313fd57.png

具体来说,从 N 个 key 中选出 K 个具有代表的 key,每个 key 对应一个 value,然后跟 query 做点乘。然后做 gradient-decent,更新 value。

为什么选有代表性的 key 不选有代表性的 query 呢?因为 query 跟 output 是对应的,这样会 output 就会缩短从而损失信息。

13461457554ce611baf43dec43b0397f.png

 九、数学推导

我们把 softmax 拿回来。原来的 self-attention 是这个样子,以计算b1为例:

十、代码示例

一、数据准备

import torch
x = [
      [1, 0, 1, 0], # Input 1
      [0, 2, 0, 2], # Input 2
      [1, 1, 1, 1]  # Input 3
     ]
x = torch.tensor(x, dtype=torch.float32)#将之前定义的 Python 列表转换为了一个 PyTorch 张量,并且设置了这个张量的数据类型为 torch.float32(32位浮点数)。

二、 初始化权重

w_key = [
      [0, 0, 1],
      [1, 1, 0],
      [0, 1, 0],
      [1, 1, 0]
]
w_query = [
      [1, 0, 1],
      [1, 0, 0],
      [0, 0, 1],
      [0, 1, 1]
]
w_value = [
      [0, 2, 0],
      [0, 3, 0],
      [1, 0, 3],
      [1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)

三、导出key, query and value的表示

keys = x @ w_key #矩阵乘法
querys = x @ w_query #矩阵乘法
values = x @ w_value #矩阵乘法
print(keys)
print(querys)
print(values)

运行结果如下:

四、计算输入的attention scores

attn_scores = querys @ keys.T

运行结果如下:

五、计算softmax 

from torch.nn.functional import softmax
attn_scores_softmax = softmax(attn_scores, dim=-1) #softmax 操作是在最后一个维度上进行的
attn_scores_softmax = [
      [0.0, 0.5, 0.5],
      [0.0, 1.0, 0.0],
      [0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)

六、将attention scores乘以value

weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]

七、对加权后的value求和以得到输出

outputs = weighted_values.sum(dim=0)

运行结果如下:

总结

注意力机制被广泛应用于各种模型中,它可以帮助模型更有效地处理输入信息。Local Attention、Global Attention 和 Stride Attention 是注意力机制的三种常见形式。通过本周的学习,加深了对自注意力机制的理解,通过代码示例,能够掌握自注意力机制的计算流程以及步骤。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值