LongNet实现10亿token的关键在于提出了dilated attention机制。
具体来说,dilated attention与标准attention的主要区别在于:
1. Dilated attention的注意力矩阵不是全连接的,而是按距离进行了稀疏化。即只允许token与距离其固定步长(称为dilation)的其他token进行attention。
举一个例子:
假设我们有一个序列,长度为8个token:
[x1, x2, x3, x4, x5, x6, x7, x8]
标准的self-attention会计算每个token与所有其他token之间的attention。
而dilated attention假设dilation factor设置为:1,2,4。
则x1会attend到:
x1自己、x2、x5
x2会attend到:
x2自己、x3、x6
x3会attend到:
x3自己、x4、x7
x4会attend到:
x4自己、x5、x8
以此类推。
可以看出,dilated attention通过设定dilation factor的疏密,控制了每个token可以attend的范围。
近距离有更密集的interaction,远距离变得更稀疏。但由于dilation factor呈指数级增长,仍然保证了全局建模能力。这种机制降低了计算量,同时保持了Transformer的表达能力。
2. 随着距离的增大,dilation也指数级增大,从而实现了对数复杂度。例如,dilation可以设置为1,2,4,8等等。
3. 通过这种方式,近距离的token可以高效地建模局部依赖,远距离的token可以捕捉全局依赖,实现了高效的长距离建模。
4.Dilated attention的注意力矩阵可以高效存储并实现快速 querying。
5.可以无缝地取代标准self-attention,并配合存在的优化技术,如局部注意力、稀疏注意力等。
6.可以与分布式训练无缝衔接,不同device处理矩阵的不同部分。
ref:GitHub - kyegomez/LongNet: Implementation of plug in and play Attention from "LongNet: Scaling Transformers to 1,000,000,000 Tokens"
https://arxiv.org/abs/2307.02486