神经网络中的权重初始化

神经网络的权重初始化

从神经网络输入和输出尽量都要有相同的方差出发,但均值很难保持一致(由于有一些非负的激活函数)。而且现有的标准化策略也是起到了同样的作用,如BN,LN等,都是努力将中间层的输出的方差和均值限定为1和0,但是最近的一些BN的工作(如用在Google T5中的RMS Norm)尝试了去掉减去均值的操作,反而会有提点的效果。说明保持均值一致并不是必须的。

💡
一个直观的猜测是,center操作,类似于全连接层的bias项,储存到的是关于预训练任务的一种先验分布信息,而把这种先验分布信息直接储存在模型中,反而可能会导致模型的迁移能力下降。所以T5不仅去掉了Layer
Normalization的center操作,它把每一层的bias项也都去掉了。

对于一个输入含有m个神经单元的网络层,输出n个值。当没有激活函数时,模型为 y i = b i + ∑ j w i , j x i , j y_i=b_i+\sum_{j}w_{i, j}x_{i, j} yi=bi+jwi,jxi,j,可以计算输出值的期望和方差,主要是看二阶矩方差,当权重和偏置的初始化的均值都是0时,并且输入值 x x x的方差是1,可以得到 E [ y i 2 ] = m E [ w i , j 2 ] E[y_i^2]=mE[w_{i,j}^2] E[yi2]=mE[wi,j2],所以当神经网络权重初始化的方差是 1 / m 1/m 1/m时,输出值y的方差也是1,这就是Lecun初始化,可以让每一层的输出值的方差都保持在1左右。

对于Xaiver初始化也考虑了反向传播时神经网络权重方差的变化,得到了两个约束,对于第 i i i层的权重 W i W^i Wi和神经元个数 n i n_i ni有两个约束条件 n i V a r [ W i ] = 1 n_iVar[W^i]=1 niVar[Wi]=1, n i + 1 V a r [ W i ] = 1 n_{i+1}Var[W^i]=1 ni+1Var[Wi]=1,所以可以使 W i ∼ N ( 0 , 2 n i + n i + 1 ) W^i\sim N(0, \frac{2}{n_i+n_{i+1}}) WiN(0,ni+ni+12),当神经网络所有层的宽度是一样时,这两个个约束条件才会同时被满足。如果使用均匀分布初始化,例如初始权重从 U ( − 1 n , 1 n ) U(-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}}) U(n 1,n 1)中采样,那么和上面一样有 E [ y i 2 ] = n i V a r ( W i ) = 1 / 3 E[y_i^2]=n_iVar(W^i)=1/3 E[yi2]=niVar(Wi)=1/3,所以为了使 E [ y i 2 ] E[y_i^2] E[yi2]的方差也为1,当W是从均匀分布中采样时,需要从 U ( − 6 n j + n j + 1 , 6 n j + n j + 1 ) U(-\frac{\sqrt{6}}{\sqrt{n_j+n_j+1}}, \frac{\sqrt{6}}{\sqrt{n_j+n_j+1}}) U(nj+nj+1 6 ,nj+nj+1 6 )中采样。

但是实际情况都有激活函数,如果算上激活函数后考虑神经网络权重的方差,如最简单的Relu激活函数,可以假设神经网络,如果神经网络权重初始化的分布还是正态分布的话,可以假设约有一半的神经单元被激活,则 E [ y i 2 ] = m 2 E [ w i , j 2 ] E[y_i^2]=\frac{m}{2}E[w_{i,j}^2] E[yi2]=2mE[wi,j2],此时就要求 E [ w i , j 2 ] = 2 m E[w_{i,j}^2]=\frac{2}{m} E[wi,j2]=m2,便是He初始化,考虑到反向传播, m = f a n _ i n m=fan\_in m=fan_in m = f a n _ o u t m=fan\_out m=fan_out,其中fan_in和fan_out分别是当前层的输入和输出。

相比于Lecun初始化使用“均值为0,方差为1/m的正态分布,其中m是当前层的神经单元个数”,还有一种NTK参数化的方法,用“均值为0、方差为1的随机分布”来初始化,但是将输出结果除以 m \sqrt{m} m ,也即是模型变为 y i = b j + 1 m ∑ i x i w i , j y_i = b_j+\frac{1}{\sqrt{m}}\sum_{i}x_{i}w_{i,j} yi=bj+m 1ixiwi,j,使用NTK参数化的好处是可以将所有参数放在 O ( 1 ) O(1) O(1)级别,所以可以设置较大的学习率。

对于transformer中的注意力机制除以 d \sqrt{d} d 的作用也便是稳定传播时候的二阶矩,因为 E [ ( q ⋅ k ) 2 ] = d E[(q\cdot k)^2]=d E[(qk)2]=d,其中d为q和k向量的维度。

标准化

相比于对权重初始化的这种微调来使模型的传播过程中的方差不会过大,还有一类比较直接粗暴的方法来使训练稳定,就是直接标准化,LN,BN等。

关于残差连接的二阶矩

残差连接x+F(x),如果x的方差为 σ 1 2 \sigma_1^2 σ12,F(x)的方差为 σ 2 2 \sigma_2^2 σ22,那么x+F(x)的方差就是 σ 1 2 + σ 2 2 \sigma_1^2+\sigma^2_2 σ12+σ22,会进一步放大,所以需要采取办法处理。原版本的Transformer和Bert直接采用了一种PostNorm的方法,也即是 x t + 1 = N o r m ( x t + F t ( x t ) ) x_{t+1}=Norm(x_t+F_t(x_t)) xt+1=Norm(xt+Ft(xt)).然而,这种做法虽然稳定了前向传播的方差,但事实上已经严重削弱了残差的恒等分支(递归到很多层之后),所以反而失去了残差“易于训练”的优点,通常要warmup并设置足够小的学习率才能使它收敛。一个针对性的改进称为Pre Norm,它的思想是“要用的时候才去标准化”,其形式为 x t + 1 = x t + F t ( N o r m ( x t ) ) x_{t+1}=x_t+F_t(Norm(x_t)) xt+1=xt+Ft(Norm(xt))

但是preNorm的效果不如Post Norm。pre Norm方法将深度网络退化成了一个宽网络,学习能力自然不如postnorm的深网络。

  • 32
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值