看本篇文章需要有一定的基础,需要了解keras的基本框架,Attention的基本原理。(这些我可以分享我在学习过程中比较重要的篇章)
需要理解三维矩阵在空间展开是什么样子,四维的呢?这个问题困扰了我一段时间,后来看到一篇很通俗易懂的文章:(现在找不到了,有空我自己写一个补上)
- MultiAttention keras实现
多头注意力的关键就是将进入attention的q,k,v进行映射,再进行attention计算,重复多次即为多头注意力。但是在TensorFlow中重复的方式并不会并行计算,而是串行计算,这也就违背了attention一开始用并行和强力的上下文依赖性两个优点秒杀CNN,RNNd的初衷。
因此在写代码的时候需要按照另一种方式来编写,先将进入attention的q,k,v映射到一个比较高的维度,比如512或者768,再将其reshape成多个头,再进行矩阵运算,在同一个矩阵中计算运算的完全可以并行计算。
具体可以见代码及注释:
class MyMultiHeadAttention(Layer):
def __init__(self,
heads,
head_size,
key_size=None,
use_bias=True,
kernel_initializer='glorot_uniform',
**kwargs
):
self.heads = heads # 多头的数目
self.head_size = head_size # v的头的大小
self.out_dim = heads * head_size # 最终的输出维度
# 一般情况下和head_size是一样的 key_size的作用可以参考苏神的低秩分解:https://kexue.fm/archives/7325
self.key_size = key_size or head_size # qk每个头的大小
self.use_bias = use_bias
self.kernel_initializer