挖来~~ 通过修改SelfAttention的执行逻辑,就可以节省大量的激活值显存开销。
通过修改SelfAttention的执行逻辑,可以节省大量的激活值显存开销。
这篇文章的消除方法来自于2021年12月10日谷歌放到arxiv上的文章self attention does not need O(n^2) memory. 该方法巧妙地使用了小学学到的加法分配率,将self attention中的固定激活值降到了O(1)的程度。[1]
Self Attention 固定激活值显存分析
Hugging face Transformers中,SelfAttention 内核实现
表格中只列举了会实测中产生激活值的操作,其中B为Batch_size,L为sequence_length,H为hidden_size,m为SelfAttention中head的数量。
SelfAttention 固定激活值显存优化
1. Prerequisites
1.1 Softmax 计算过程
写成伪代码则为:
2. 显存优化
Google采用了一个非常简单的方法来节省Attention核中的大量的显存开销,具体计算过程为:
来避开原始的实现中所产生的A和S矩阵。
写成伪代码:
一个可行的PyTorch api实现,但是效率很低很低,不可能用的。效率想要高估计还是需要用CUDA去写个算子...按照文章的说法,实现的好的话,推断的时候是可以比原始方法要快的,但是就训练而言,这里在后向过程中肯定需要进行丢失信息的重计算,论文里可以预见的会被原始方法慢两倍。
不想写CUDA又想要提升性能的话,可以考虑narrow的时候多取几行或者几列,跟GPU的核数对应上应该比较合适(文章里是4096,也忒大了),然后换成einsum的张量乘法实现可调整遍历窗口大小的优化方法。
总结
- 这个方法跟原始方法在逻辑上是等价的,而且计算复杂度也是一致的。
- 使用的时候需要注意在计算指数的时候可能会存在的溢出问题(这个原始实现里也有),因此文章里面的实现在做指数运算前减去了最大的A_ij值。
- 收敛性相同,且在训练小Transformer时有4个百分点的速度提升。
- 需要在Backward的时候重计算丢失掉的信息,这里可能会影响到dropout,所以dropout的结果我猜肯定在前向的时候是不能被丢弃的。
- 推理系统的福音,可以调整并降低中间产生的激活值峰值,同时保证一定的推理速度。