Paper Reading: DBN
Decorrelated Batch Normalization
CVPR2018. Github(Written in Lua). Paper.
1. Intuition
Ioffe & Szegedy(2017) 提出的Batch Normalization:
x i ^ = γ x i − μ σ 2 + ϵ , where μ = 1 m ∑ j = 1 m μ j , σ 2 = 1 m ∑ j = 1 m ( x j − μ ) 2 \hat{x_i}=\gamma\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}}, \quad\text{where}\quad\mu=\frac{1}{m}\sum_{j=1}^m\mu_j,\sigma^2=\frac{1}{m}\sum_{j=1}^m(x_j-\mu)^2 xi^=γσ2+ϵxi−μ,whereμ=m1∑j=1mμj,σ2=m1∑j=1m(xj−μ)2
但是最关键的问题是 m m m个维度之间关联度较高。
因此本文从ZCA白化入手提出了DBN, i.e. Decorrelated Batch Normalization:
x i ^ = Σ − 1 2 ( x i − μ ) \hat{x_i}=\Sigma^{-\frac{1}{2}}(x_i-\mu) xi^=Σ−21(xi−μ)
实现DBN的四个问题?
- DBN如何进行反馈传播?
- 为何选ZCA不选PCA?
- 如何计算 Σ − 1 2 \Sigma^{-\frac{1}{2}} Σ−21矩阵?
- 白化操作的样本量如何确定?
2. 算法细节
2.1 Notation
令
X
∈
R
d
×
m
\mathbf{X}\in\mathbb{R}^{d\times m}
X∈Rd×m,
d
d
d为维度,
m
m
m为mini-batch的大小,白化变换
ϕ
:
R
d
×
m
→
R
d
×
m
\phi:\mathbb{R}^{d\times m}\rightarrow\mathbb{R}^{d\times m}
ϕ:Rd×m→Rd×m可定义为:
ϕ
(
X
)
=
Σ
−
1
2
(
X
−
μ
⋅
1
T
)
\phi(\mathbf{X})=\Sigma^{-\frac{1}{2}}(\mathbf{X}-\mu\cdot\mathbf{1}^T)
ϕ(X)=Σ−21(X−μ⋅1T)
其中
μ
=
1
m
X
⋅
1
,
Σ
=
1
m
(
X
−
μ
⋅
1
T
)
(
X
−
μ
⋅
1
T
)
T
+
ϵ
⋅
I
\mu=\frac{1}{m}\mathbf{X}\cdot\mathbf{1},\Sigma=\frac{1}{m}(\mathbf{X}-\mu \cdot\mathbf{1}^T)(\mathbf{X}-\mu \cdot\mathbf{1}^T)^T+\epsilon\cdot\mathbf{I}
μ=m1X⋅1,Σ=m1(X−μ⋅1T)(X−μ⋅1T)T+ϵ⋅I。白化之后目的是
X
^
=
ϕ
(
X
)
s
.
t
.
X
^
X
^
T
=
I
\hat{\mathbf{X}}=\phi(\mathbf{X})\quad s.t. \hat{\mathbf{X}}\hat{\mathbf{X}}^T=I
X^=ϕ(X)s.t.X^X^T=I
那实现DBN该如何解决上述四个问题呢?
2.2 随机轴交换
Σ P C A − 1 2 = Λ − 1 2 D \Sigma^{-\frac{1}{2}}_{PCA}=\Lambda^{-\frac{1}{2}}\mathbf{D} ΣPCA−21=Λ−21D
正交特征向量 D = [ d 1 , ⋯ , d d ] \mathbf{D}=[\mathbf{d}_1,\cdots,\mathbf{d}_d] D=[d1,⋯,dd]是理论上是不能够被唯一确定的,但是PCA就选定了 Λ \Lambda Λ按照 σ 1 , ⋯ , σ d \sigma_1,\cdots,\sigma_d σ1,⋯,σd从大到小的顺序进行排列,即旋转坐标轴到方差最大的方向。
但是恰恰是由于PCA总是选取方差最大的方向,而神经网络的激活数值会随权重更新而改变,这使得其在不同的batch、iteration的过程中的所旋转的坐标轴方向也都是不同的,也就是所说的stochastic axis swapping。这也导致了如下的loss震荡、收敛效果差的问题:
2.3 ZCA白化的BP反向传播过程实现
公式推导+解释
鉴于上述情况选取ZCA白化 Σ Z C A − 1 / 2 = D Λ − 1 2 D T \Sigma^{-1/2}_{ZCA}=\mathbf{D}\Lambda^{-\frac{1}{2}}\mathbf{D}^T ΣZCA−1/2=DΛ−21DT.
那么ZCA白化后续的梯度下降的反向传播就应当如下所示:
2.3.1 Forward Pass (Appendix A.1)
从 x j → x j ^ x_j\rightarrow\hat{x_j} xj→xj^
2.3.2 Back Propagation (Appendix A.2)
BP过程从 L → x i L\rightarrow x_i L→xi
简化版本:
∂
L
∂
x
i
=
(
∂
L
∂
x
~
i
−
f
+
x
~
i
T
S
−
x
~
i
T
M
)
Λ
−
1
/
2
D
T
\frac{\partial L}{\partial \mathbf{x}_{i}}=\left(\frac{\partial L}{\partial \tilde{\mathbf{x}}_{i}}-\mathbf{f}+\tilde{\mathbf{x}}_{i}^{T} \mathbf{S}-\tilde{\mathbf{x}}_{i}^{T} \mathbf{M}\right) \Lambda^{-1 / 2} \mathbf{D}^{T}\\
∂xi∂L=(∂x~i∂L−f+x~iTS−x~iTM)Λ−1/2DT
2.4 算法流程
在每一轮训练中的前馈传播和反馈传播的伪算法。
2.4.1 前馈传播算法
白化是发生在每一个mini-batch中的, μ \mu μ和 Σ \Sigma Σ在每一个batch中都是迭代更新的。
- 在Forward Pass中先计算PCA,ZCA也是在PCA基础上计算得到。(与直接用 W z c a = Σ − 1 / 2 W_{zca}=\Sigma^{-1/2} Wzca=Σ−1/2相比,不知数值上有区别否?)
- 步骤10&11的 λ \lambda λ就是个相同的超参,不是严格的moving average迭代更新 μ n + 1 = n n + 1 μ n + 1 n + 1 X n + 1 \mu_{n+1}=\frac{n}{n+1}\mu_n+\frac{1}{n+1}X_{n+1} μn+1=n+1nμn+n+11Xn+1
2.4.2 反向传播算法
值得一提的是,在CNN算法中,DBN的输入形如 X C ∈ R h × w × d × m \mathbf{X}_C\in\mathbb{R}^{h\times w\times d\times m} XC∈Rh×w×d×m, h , w h,w h,w分别表示feature map的维度(height × \times ×width), d , m d,m d,m分别表示feature maps的数量和batch中样本的数量。
2.5 组白化
为了保证每个batch中有足够的样本数量来做白化操作,我们将激活层数值沿特征维度 d d d,划分成 k G ( k G < d ) k_G(k_G<d) kG(kG<d)个较小的组(防止出现batch样本数 m < < d m<<d m<<d的情况),并在每个组中进行白化。 k G = 1 k_G=1 kG=1时,DBN退化为BN。此时计算复杂度从 O ( d 2 max ( m , d ) ) O(d^2\max(m,d)) O(d2max(m,d))降至 O ( d k G ( k G 2 ( max ( m , k G ) ) ) ) O(\frac{d}{k_G}(k_G^2(\max(m,k_G)))) O(kGd(kG2(max(m,kG)))),通常,我们选择 k G < m k_G<m kG<m,此时组白化计算复杂度为 O ( m d k G ) O(mdk_G) O(mdkG)。
3. 实验结果
3.1 消融实验
3.2 CNN和ResNet
在CNN上比较BN与DBN
加在ResNet中加入DBN效果也会提高一部分。
作者在附录B中也讨论了DBN其实比较time costly,主要是ZCA中矩阵求逆过于耗时,所以该组也在CVPR19中提出了采用Iterative Normalization,并在CVPR21上发表了Group Whitening详细探讨了组白化的内容。