cs231n'18: Assignment 2 | Batch Normalization

Assignment 2 | Batch Normalization上文吐槽BN部分讲的太烂,2018年果然更新了这一部分,slides里加了好多内容,详见Lecture 6的slides第54到61页,以及Lecture 7的slides第11到18页,这里结合着原始论文和作业,把BN及其几个变种好好总结一下。Batch NormalizationTrain前面的作业中已...
摘要由CSDN通过智能技术生成

Assignment 2 | Batch Normalization

上文吐槽BN部分讲的太烂,2018年果然更新了这一部分,slides里加了好多内容,详见Lecture 6的slides第54到61页,以及Lecture 7的slides第11到18页,这里结合着原始论文和作业,把BN及其几个变种好好总结一下。

Batch Normalization

Train

前面的作业中已经见识到了,weight初始化时方差的调校真的是很麻烦,小了梯度消失不学习,大了梯度爆炸没法学习。
即使开始初始化的很好,随着学习的深入,网络的加深,每一层的方差已经不再受控;另外,特别是对于刚开始的几层,方差上稍微的变化,都会在forward prop时逐级放大的传递下去。
作业中只是三五层的小网络,要是几十上百层的网络,可以想象学习几乎是不可能的。

既然每一层输入的方差会产生如此多的问题,这就产生了第一个想法,何不将每一层的输入直接标准化为0均值单位方差。由于NN的train多是基于mini-batch的,所以这里标准化也是基于mini-batch。

输入x是包含N个sample的mini-batch,每个sample有D个feature。对每个feature进行标准化,即:

μjσ2j=1Ni=1Nxi,j=1Ni=1N(xi,jμj)2 μ j = 1 N ∑ i = 1 N x i , j σ j 2 = 1 N ∑ i = 1 N ( x i , j − μ j ) 2

标准化后的输出为:
x^=xμjσ2j+ϵ x ^ = x − μ j σ j 2 + ϵ

但是但是但是,这里武断的使输入均值为0,方差为1真的是最好的选择么?不一定。如果不是最好的选择,
设为多少是最好的选择呢?不知道。不知道的话怎么办呢?
那就让NN自己去学习一个最好的去呗。所以才有了下一步:

y=γx^+β y = γ ⋅ x ^ + β

其中, γ γ β β 是要学习的参数,将输入的均值和方差从(0,1)又拉到了 (γ,β) ( γ , β )

所以,通常说起来BN是一层,但是我认为,BN是两层:Normalization Layer和Shift Layer,这两层是紧密相连,不可分割的。其中,Normalization Layer将输入的均值和方差标准化为(0,1),Shift Layer又将其拉到 (γ,β) ( γ , β ) 。这里, (γ,β) ( γ , β ) 和其他的weight、bias一样,都是通过backprop算梯度,然后再用SGD等方法更新学习得到。

好,这里强调两个问题,也是我第一遍看paper时的疑惑,也是2017年视频中那位小姑娘讲课时犯的错误:

  1. 一提到BN层的作用,马上想到的是:将输入映射为0均值单位方差的高斯分布。错!首先它不一定是高斯分布,可以是任意的分布,BN仅仅改变均值方差,不改变分布。其次,均值方差不是(0,1),而是 (γ,β) ( γ , β ) 。说(0, 1)的是忘记了shift这一层。
  2. 原文中有一句,还打了斜体:

    To address this, we make sure that the transformation inserted in the network can represent the identity transform.


当时看的时候就不明白,既然费半天劲减均值除方差,怎么这里又要 “represent the identity transform”? 而且加上后边的 (γ,β) ( γ , β ) 操作,就更看不懂了。其实这里漏看了一个 “can” 。既然 (γ,β) ( γ , β ) 是学习来的,它们当然可以是原始输入的均值方差了,所以BN有表达一个identity transform的能力,而不是必须要表达一个identity transform。 总结一下:
input:
      x: (N, D)
intermediates:
      mean: (1, D)  
          mean = np.mean(x, axis=0)
      var: (1, D)
          var = np.var(x, axis=0)
      xhat: (N, D)
          xhat = (x - mean) / (np.sqrt(var + eps))
learnable params:
      gamma: (1, D)
      beta: (1, D)
输出:
      y = gamma * xhat + beta

Test

在test时,就没有mini-batch可用来算 μ μ σ2 σ 2 了,此时常用的方法是在train的过程中记录一个 μ μ σ2 σ 2 的滑动均值在test的时候使用。 BN通常放在FC/Conv之后,ReLU之前。

Backprop

BN的backprop是这次作业的难点,还要用两种方法做,这里一步一步尽量详细地把推导过程写出来。
dβ d β
dβ d β 用维度分析法:
y=γx^+β y = γ ⋅ x ^ + β
其中 y y 形如(N, D), γ β β 形如(D,), x^ x ^ 形如(N, D),所以 dβ d β 必然为:
dbeta = np.sum(dout, axis=0)
这里就不赘述了。
dγ d γ
其实 dγ d γ 也可以用维度分析法得到, dy d y dx^ d x ^ 都形如(N, D),而 dγ d γ 形如(D,),显然 dγ d γ 应为:
dgamma = np.sum(xhat * dout, axis=0)
这里还是把过程写一下吧
y11y21yN1y12y22...yN2............y1Dy2DyND=[γ1γ2...γD]x11x21xN1x12x22...xN2............x1Dx2DxND [ y 11 y 12 . . . y 1 D y 21 y 22 . . . y 2 D . . . . . . y N 1 y N 2 . . . y N D ] = [ γ 1 γ 2 . . . γ D ] ⋅ [ x 11 x 12 . . . x 1 D x 21 x 22 . . . x 2 D . . . . . . x N 1 x N 2 . . . x N D ]
展开可得:
y11=γ1x11,y21=γ1x21,y12=γ2x12,y22=γ1x22,...... y 11 = γ 1 ⋅ x 11 , y 12 = γ 2 ⋅ x 12 , . . . y 21 = γ 1 ⋅ x 21 , y 22 = γ 1 ⋅ x 22 , . . .
由此可得:
Lγq=Lyyγq=i,jLyijyijγq ∂ L ∂ γ q = ∂ L ∂ y ⋅ ∂ y ∂ γ q = ∑ i , j ∂ L ∂ y i j ⋅ ∂ y i j ∂ γ q
而仅当 j=q j = q 时有
yijγq=xiq ∂ y i j ∂ γ q = x i q
其余均为0,故:
Lγq=i=1NLyiqyiqγq=i=1Nxiqdyiq ∂ L ∂ γ q
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值