normalization归一化算子和方差计算数值稳定性方法

背景

在深度学习模型中,很多算子涉及方差计算,如layer norm, instance norm, group norm等等。方差计算有不同的计算方法,而不同的计算方法具有不同的数值稳定性和计算准确度。instance norm, group norm等算子的求和元素N显著更大,面临数值稳定性问题更加显著。

参考

Algorithms for calculating variance

https://en.wikipedia.org/wiki/Catastrophic_cancellation

Algorithms for computing the sample variance: Analysis and recommendations

Numerically stable parallel computation of (co-)variance

https://en.wikipedia.org/wiki/Kahan_summation_algorithm

方差计算公式

方差计算一般有两种方法:

第一种方法需要先对数据遍历一遍求均值,再对数据遍历一遍求方差,因此需要访问两次数据。因为多次内存访问会明显降低计算性能,第二种计算方法只需要遍历一次数据求x和x平方的均值,因此计算性能更高,实际使用很多,这里称之为text_book算法。

Catastrophic cancellation灾难性抵消

https://en.wikipedia.org/wiki/Catastrophic_cancellation

由样本x计算得到的L1和L2,L1和L2分别带有一定的误差,同时L1和L2的数值比较接近。即使L1和L2与理论值的相对误差较小,但是L1减去L2的结果S_a与其理论上S_e相比可能具有显著大的相对误差。

在方差的计算过程中,如果x为0均值,那么x平方的均值E(x^2)和x均值的平方E(x)^2不会太接近,cancellation问题不大。但是如果x均值具有明显的偏置,那么就导致E(x^2)和E(x)^2很接近。实际上深度学习的激活值通常不是0均值的,具有明显的bias。

简单改进:减去偏置

Var(x) = Var(x - K)

K可以选一个接近Mean的估计。实际上选择x第一个数也行。

the closer K is to the mean value the more accurate the result will be, but just choosing a value inside the samples range will guarantee the desired stability. If the values xi-K are small then there are no problems with the sum of its squares, on the contrary, if they are large it necessarily means that the variance is large as well. In any case the second term in the formula is always smaller than the first one therefore no cancellation may occur.[2]

Welford's online algorithm

这里online algorithm的含义是计算出已有数据的方差后,来一个新的元素,可以直接基于之前计算的结果增量计算,而不用重新遍历之前的数据计算。

An example Python implementation for Welford's algorithm is given below.

import numpy as np


def var_text_book(data, dtype=None):
    mean2 = np.array(0, dtype=data.dtype)
    mean = np.array(0, dtype=data.dtype)
    if dtype is not None:
        mean2 = mean2.astype(dtype)
        mean = mean.astype(dtype)

    n = len(data)
    for i in range(n):
        x = data[i]
        mean2 += x*x
        mean += x

    mean2 /= n
    mean /= n
    var = mean2 - mean * mean
    var1 = var * n / (n-1)
    return mean, var, var1


def welford_update(existing_aggregate, new_value):
    # For a new value new_value, compute the new count, new mean, the new M2.
    # mean accumulates the mean of the entire dataset
    # M2 aggregates the squared distance from the mean
    # count aggregates the number of samples seen so far
    (count, mean, M2) = existing_aggregate
    count += 1
    delta = new_value - mean
    mean += delta / count
    delta2 = new_value - mean
    M2 += delta * delta2
    return (count, mean, M2)


def welford_var(data, dtype=None):
    count = np.array(0, dtype="int32")
    mean = np.array(0, dtype=data.dtype)
    M2 = np.array(0, dtype=data.dtype)
    if dtype is not None:
        mean = mean.astype(dtype)
        M2 = M2.astype(dtype)

    existing_aggregate = (count, mean, M2)

    for new_value in data:
        existing_aggregate = welford_update(existing_aggregate, new_value)

    count, mean, M2 = existing_aggregate
    n = len(data)
    return mean, M2/n, M2/(n-1)


data = np.random.randn(20480) + 4

data = data.astype(np.float16)

var_ref = np.var(data.astype(np.float32))

mean_text, var_text, var1_text = var_text_book(data)
mean_text, var_text_fp32, var1_text_fp32 = var_text_book(data, "float32")

mean_welford, var_welford, var1_welford = welford_var(data)
mean_welford, var_welford_fp32, var1_welford_fp32 = welford_var(data, "float32")

This algorithm is much less prone to loss of precision due to catastrophic cancellation, but might not be as efficient because of the division operation inside the loop. For a particularly robust two-pass algorithm for computing the variance, one can first compute and subtract an estimate of the mean, and then use this algorithm on the residuals.

The parallel algorithm below illustrates how to merge multiple sets of statistics calculated online.

Parallel algorithm并行算法

def parallel_variance(n_a, avg_a, M2_a, n_b, avg_b, M2_b):
    n = n_a + n_b
    delta = avg_b - avg_a
    M2 = M2_a + M2_b + delta**2 * n_a * n_b / n
    var_ab = M2 / (n - 1)
    return var_ab

每个硬件计算单元计算独立使用上面的Welford算法计算,然后用该方法把每个计算单元的结果合并起来得到一个最终的结果。

实践测试

使用上面的测试代码进行一些测试,可以发现:

1,如果数据x均值为0,那么text_book算法和Welford算法准确度接近,并且在x方差较大时均容易溢出导致inf/nan。

2,如果x均值不为0,那么Welford准确度显著高于text_book算法。并且Welford在x绝对值较大时也不容易溢出,而text_book算法很容易溢出。因此总体来说Welford数值稳定性更高,能够适应更加多变的数据分布情况。但是Welford计算复杂度更高,这个可能导致计算速度不如text_book算法。

3,如果硬件能支持float32的情况下,如果对均值和均值平方累加器使用float32类型,在深度学习的场景下,text_book算法和Welford算法都能获得比较高的准确度。而只能只用float16时,为了结果准确性应该使用Welford算法。但使用float16的Welford算法准确度是远不如float32的算法的。这导致一个问题:在计算速度上,使用float32的text_book算法还是float16的Welford算法哪个更快是需要进一步评估的,如果前者更快,那么float32的text_book算法仍然可以实现数值精度和性能的双重保证。

  • 16
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
内容介绍 项目结构: Controller层:使用Spring MVC来处理用户请求,负责将请求分发到相应的业务逻辑层,并将数据传递给视图层进行展示。Controller层通常包含控制器类,这些类通过注解如@Controller、@RequestMapping等标记,负责处理HTTP请求并返回响应。 Service层:Spring的核心部分,用于处理业务逻辑。Service层通过接口和实现类的方式,将业务逻辑与具体的实现细节分离。常见的注解有@Service和@Transactional,后者用于管理事务。 DAO层:使用MyBatis来实现数据持久化,DAO层与数据库直接交互,执行CRUD操作。MyBatis通过XML映射文件或注解的方式,将SQL语句与Java对象绑定,实现高效的数据访问。 Spring整合: Spring核心配置:包括Spring的IOC容器配置,管理Service和DAO层的Bean。配置文件通常包括applicationContext.xml或采用Java配置类。 事务管理:通过Spring的声明式事务管理,简化了事务的处理,确保数据一致性和完整性。 Spring MVC整合: 视图解析器:配置Spring MVC的视图解析器,将逻辑视图名解析为具体的JSP或其他类型的视图。 拦截器:通过配置Spring MVC的拦截器,处理请求的预处理和后处理,常用于权限验证、日志记录等功能。 MyBatis整合: 数据源配置:配置数据库连接池(如Druid或C3P0),确保应用可以高效地访问数据库。 SQL映射文件:使用MyBatis的XML文件或注解配置,将SQL语句与Java对象映射,支持复杂的查询、插入、更新和删除操作。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Luchang-Li

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值