Transformers 中的Softmax 和 Layer Norm 如何并行?

1.Softmax 如何并行?

        Softmax 计算公式:

        安全的 Softmax 运算:

        softmax 有个问题,那就是很容易溢出。比如采用半精度,由于float16的最大值为65504,所以只要x>=11,那么softmax就溢出了。即使是float32,x也不能超过88。

        好在 exp 有这么一个性质,那就是

        

         根据这个性质,可以在分子分母上同时除以一个数,这样可以将  的范围都挪到非正实数域。

这样,就可以保证计算 softmax 时的数值稳定性。

这个算法可以分成三次迭代来执行。

  1. 求 x 的最大值 m

       2. 计算 softmax 分母

        3.求对应位置的 softmax

        分析上面的步骤,可以发现,如果是不做任何优化的话,至少要进行和 GPU 进行6次通信(3次写入,3次写出)。

        如果对每一步的for 循环进行一些并行切分的的话,还要加上 reduce_sum 和 reduce_max 之类的通信成本。

        是否能将某些操作进行融合,减少通信呢?按照之前 layernorm 并行的经验,我们需要寻找一个 Online Algorithm。

Online Softmax

        2018年 Nvidia 提出了《Online normalizer calculation for softmax》

        既然是 Online 的算法,我们需要找出递归的表达式。

        对于第二步中的我们期望去掉这个式子对

的依赖。

设 ,,注意,这里减去的全局最大值变成了当前最大值。这个式子有如下的性质:

        还能不能进一步融合算子呢?没办法了,因为第二步的分母依赖于第一步的计算。

        但是可以借助 GPU 的 share memory 来存储中间结果,将上面的两步只用一个 kernel 实现,这样就只需要与 global memory 通信两次,一次写入数据,一次读取结果。

整体来说,有两个重要的优化点:

  1. 将前两步的算子融合,减少 Reduce_max 和 Reduce_sum 之类的通信成本。

  2. 借助 share memory 存储中间结果,减少与 global memory 的通信成本。

        这一篇只是从数学上给出了一些 Softmax 的并行理论基础。具体实现还有很多细节上的优化点,比如:

感兴趣的可以看看 oneflow 的一个 softmax 深度优化:https://www.oneflow.org/a/share/jishuboke/54.html . 源代码在https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/softmax.cuh

还有 Nvidia 自己实现的一个可读性很好的版本:https://github.com/NVIDIA/FasterTransformer/blob/release/v1.0_tag/fastertransformer/cuda/open_attention.cu#L189-L268 但是速度没有 oneflow 的好。

2.Transformers 中的 Layer Norm 可以并行加速么?

        这个问题我之前觉得可以加速,而且给出了一个简单的实现方案。后来看 Transformers 的一些 GPU 训练的代码后,才发现我真是 too young too simple, sometimes even naive。

        layernorm 的计算,重点就是计算均值和方差。分两步:

实际上的并行方案

上面的方案当然没什么问题,但是并不是最优的。

上面的算法需要遍历2次数据,一次计算均值,一次计算方差。能不能只遍历数据一次就能并行的把均值和方差算出来呢?

相信你会立马想到这个公式:

 

        并行的时候,一边算平方和,一边算全部的和。最后平方和与均值都可以算出来,然后按公式一减就出来了,看上去十分的 Perfect。

        但是这个公式只是理论上很完美,受限于计算机计算精度的问题,这个公式当两个平方项都很大的时候,精度会失真,导致算出来的方差很不稳定,甚至有可能是负数。后面会有代码演示数值稳定性的问题。

        那能像上一节的算法那样,分别计算均值和方差最后聚合么?似乎有些反直觉,不需要知道全局的均值就可以计算方差。但是我们要相信数学家的折腾能力,搞出了无数匪夷所思的东西。就连加百列号角(Gabriel's Horn)这种鬼玩意都能搞出来,数学有无限的可能。(注:加百列号角 Gabriel's Horn 的体积是有限的,但是表面积是无限的。)

        Transfromers 无论是在 pytorch,还是在 apex,还是在其他一些加速框架比如 oneflow 中,都采用了 Welford online Algorithm。这个算法是 Welford 在1962年发表的《Note on a Method for Calculating Corrected Sums of Squares and Products》中提出。他给出的算法,可以在一个集合新增一个元素的时候,均值和方差的不需要把所有的数都遍历一遍,而是根据之前集合的均值和方差就可以直接计算出来。

        而在1972年,Chan 发表了《Updating Formulae and a Pairwise Algorithm for Computing Sample Variances》,可以认为是 Welford Algorithm 的一个升级版本,可以根据两个集合的均值和方差直接计算出整体的均值和方差。当然如果两个集合中,某一个集合只有一个元素,算法就退化成 Welford Algorithm 了。这个算法为大规模并行计算均值和方差提供了理论基础。

        由于 Welford's Algorithm 是 Chan's Algorithm 的一个特例,所以下面简单说一下 Chan's Algorithm 是怎么一回事。

        这里首先给出一个定义,定义与均值差的平方和为

        也就是说我们只需要两个集合各自的均值和 M2 我们就可以计算出方差。

        上面这个式子怎么来的呢?我们来证明一下:

        证明完了好像也没那么神奇,陷入了人生三大错觉之一:我上我也行,只恨自己生的太晚。事后诸葛亮就是这么自信满满。 

代码学习

        由于Chan 和 Welford 算法的并行体质,Nvidia 的 Apex 库率先实现了这个方法,叫做 Fused Layer Norm。为啥叫 Fused ?因为把所有的计算都融合(fuse)到一个核函数里了,不需要与 CPU 来回通信。可以重点看代码开头的 cuWelfordOnlineSum 和 cuChanOnlineSum 两个函数。对应 python 代码入口为 apex.normalization.fused_layer_norm.FusedLayerNorm。代码见:https://github.com/NVIDIA/apex/blob/c3fad1ad120b23055f6630da0b029c8b626db78f/csrc/layer_norm_cuda_kernel.cu#L670

        pytorch 后来也实现实现了,可以看 cuWelfordOnlineSum 和 cuWelfordCombine 两个函数,代码见:https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/layer_norm_kernel.cu

        Oneflow 后来又进一步根据输入的大小优化了 Fused Layer Norm 的性能,代码见:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh

        所以现在我们使用 pytorch 和其他加速库的 layernorm 函数底层已经实现了并行。我们这些调包侠在用 python 写代码的时候,要记住,哪有什么岁月静好,都是 C++ 和 Cuda 大佬们在负重前行

        下面我用 python 模拟了一下 c++ cuda 的实现,同时测试了一下在数字比较大的时候的数值稳定性问题。可以发现,用平方和减去均值平方的方法,方差就算错了,成为了负值。

  测试普通数字...
    全局均值: -0.10384651739409387
    Welford 并行全局均值: -0.10384651739409384
    串行全局方差: 0.8165221946938586
    平方差串行全局方差: 0.816522194693858
    Welford 并行全局方差: 0.8165221946938584
    --------------------------------------------------
    测试大数...
    全局均值: 999999999.8961536
    Welford 全局均值: 999999999.8961536
    串行全局方差: 0.8165221933047772
    平方差串行全局方差: -512.0
    Welford 并行全局方差: 0.8165221874239014
    --------------------------------------------------

        核心代码如下,全部的代码实在是有些又臭又长,就放在开篇提到的电子书里了。

# 核心代码
    def welford_combine(val, mean, m2, count):
        """新增一个数"""
        count += 1
        delta1 = val - mean
        mean += delta1 / count
        delta2 = val - mean
        m2 += delta1 * delta2
        return mean, m2, count
    
    def welford_combine_two(b_mean, b_m2, b_count, mean, m2, count):
        """合并两个集合"""
        if b_count == 0:
            return mean, m2, count
        new_count = count + b_count
        nb_over_n = b_count / new_count
        delta = b_mean - mean
        mean += delta * nb_over_n
        m2 += b_m2 + delta * delta * count * nb_over_n
        count = new_count
        return mean, m2, count

  • 21
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值