Welford算法解决layernorm问题

博客探讨了在将FP32计算转换为FP16时LayerNorm可能出现的nan或inf问题,介绍了Welford算法作为更高效的方差计算方法。通过代码示例展示了Welford算法如何更新均值和方差,以减少精度损失。文章强调了动态规划在优化高复杂度算法中的应用,并提供了两种计算方式的对比。
摘要由CSDN通过智能技术生成

背景

在利用框架做计算的时候,经常会遇到layernorm的问题,不知道有没有小伙伴发现,当fp32切到fp16的时候,有时候直接结果为nan或者为inf了,为此需要研究一下。

原理

其实layernorm的核心就是计算方差,定义的公式如下,但是实际上考虑到计算效率的问题,我们会采用FP32的公式来实现,具体可以节省多少计算量,有兴趣可以试一下,不过当把fp32强行切换到fp16的时候,就会出现误差,导致位置错误。
在这里插入图片描述

welford算法

之前很多框架采用的都是上面的fp32的算法,下面来看看一种新的计算方式
在这里插入图片描述

推导

均值的公式很好推导,就不展开了,直接来看方差的推导,根据FP32的公式可以知道;
在这里插入图片描述
接下来可以得到:
在这里插入图片描述
在这里插入图片描述
化简得到:
在这里插入图片描述
根据前面的均值知道:
在这里插入图片描述
替换后最终:
在这里插入图片描述
进一步化简为:
在这里插入图片描述

代码实现

import numpy as np


def welford_update(count, mean, M2, currValue):
    count += 1
    delta = currValue - mean
    mean += delta / count
    delta2 = currValue - mean
    M2 += delta * delta2
    return (count, mean, M2)


def naive_update(sum, sum_square, currValue):
    sum = sum + currValue
    sum_square = sum_square + currValue * currValue
    return (sum, sum_square)


x_arr = np.random.randn(100000).astype(np.float32)

welford_mean = 0
welford_m2 = 0
welford_count = 0
for i in range(len(x_arr)):
    new_val = x_arr[i]
    welford_count, welford_mean, welford_m2 = welford_update(welford_count, welford_mean, welford_m2, new_val)
print("Welford mean: ", welford_mean)
print("Welford var: ", welford_m2 / welford_count)

naive_sum = 0
naive_sum_square = 0
for i in range(len(x_arr)):
    new_val = x_arr[i]
    naive_sum, naive_sum_square = naive_update(naive_sum, naive_sum_square, new_val)
naive_mean = naive_sum / len(x_arr)
naive_var = naive_sum_square/ len(x_arr) - naive_mean*naive_mean
print("Naive mean: ", naive_mean)
print("Naive var: ", naive_var)

本质

利用动态规划来解决复杂度较高的算法问题。

参考

链接:https://cloud.tencent.com/developer/article/1877323

插曲

两个数组合并后的方差和均值:https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值