Transformer 时间复杂度和空间复杂度高 o(n^2) 是因为每一步 它都需要计算该步与之前的所有context的attention 信息。这就导致Transformer在序列长度上很难扩展,对于字符级别的语言模型来说上千个token的输入很常见,transformer就无法好好使用了。
在扩展transformer 处理序列长度上 前面有2篇文章 Sparse Transformer 和 Transformer XL 分别从不同的角度来解决问题, Sparse Transformer 是通过稀疏attention 机制 减少来原始transformer计算的时间和空间复杂度 从而使得transformer能够处理长序列,而Transformer XL则是考虑到建立segment之间的依赖性,这样就算把长序列切割成一个个短的segment,也可以通过循环机制建立相互之间的依赖性 使得transformer 对长序列的处理效果提升。
而本文是facebook 团队提出的改进机制 其实与SparseTransformer 的方向类似,通过减少时间和空间复杂度来解决这个问题的。但是本文提出的观点是自适应的宽度的注意力学习机制和动态注意力机制。其实就是每个head 只学习关注的部分,不怎么关注的部分,通过mask使得权重将为0。这个只学习关注的部分的想法来源于对transformer 各个head 学到的不同的模块分析。如下图(横轴-100 到-20是指上下文的距离) 可以看到 headA 只关注的是local的信息(离它最近的20个 权重高,前面的80个的权重很低),head B关注的是global信息(最近的高,前面的也高)。所以作者想 是否可以设置一种自适应的机制 对于每个head不关注的地方 干脆不要学习,这样也可以节省 空间和时间(比如head A的前面的80个 权重很低的 context 干脆不要学习,省掉这部分计算)。
借助mask 机制来实现对不关注的context进行屏蔽
自适应宽度的实现借助了下面的mask函数:其中参数x是 context 与当前token的距离。mask函数就是把上下文和当前token 之间的距离 映射到[0,1]之间。
在计算attention的权重的时候,使用context和当前位置的距离来作为mask函数的输入。在学习过程中,z是需要学习的参数,对每个head,z都是不同的。attention权重的计算如下图:
在损失函数中,给z添加一个L1 penalization:
也可以使用网络的形式 动态的学习参数z,即z是基于当前输入的一个输出。这种方式被称为动态宽度。
结果:
可以看到,相对于普通的Transformer来说,参数量并没有太大的降低,但是计算量却会有三四个数量级的减少。
随着input长度的增长,平均宽度,计算量的对比如下图:可以看到,即使输入边长,计算量和平均的注意力宽度变化很小。
不同层次上的平均宽度如下:可以看到,越是高层,attention的宽度越大。
参考:
Adaptive Attention Span in Transformers
Ttransformer 之自适应宽度注意力