Batch Normalization函数详解及反向传播中的梯度求导

摘要

本文给出 Batch Normalization 函数的定义, 并求解其在反向传播中的梯度

相关

配套代码, 请参考文章 :

Python和PyTorch对比实现批标准化Batch Normalization函数及反向传播

本文仅介绍Batch Normalization的训练过程, 测试或推理过程请参考 :

Batch Normalization的测试或推理过程及样本参数更新方法

系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981

正文

1. 概念

批标准化 (Batch Normalization) 的思想来自传统的机器学习, 主要为了处理数据取值范围相差过大的问题.
比如, 正常成年人每升血液中所含血细胞的数量:

项目 数量
红细胞计数 RBC 3.5 × 1 0 12 ∼ 5.5 × 1 0 12 3.5×10^{12} \sim 5.5×10^{12} 3.5×10125.5×1012
白细胞计数 WBC 5.0 × 1 0 9 ∼ 10.0 × 1 0 9 5.0×10^9 \sim 10.0×10^9 5.0×10910.0×109
血小板计数 PLT 1.5 × 1 0 11 ∼ 3.5 × 1 0 11 1.5×10^{11} \sim 3.5×10^{11} 1.5×10113.5×1011
血红蛋白 Hb 110 ∼ 160 g / L 110 \sim 160g/L 110160g/L

如果这些指标发生异常变化, 人体就可能患病.
这些数据不仅量级差别非常大, 血红蛋白的单位还和其他项目不一样, 不可能直接用于机器学习.
传统的标准化方法 (Normalization) 是将这些数据统一缩放为 0 ~ 1 之间的数据.

深度神经网络学习过程中的 Batch Normalization 与之类似, 不同点在于数据规模非常大, 只能分批处理, 故称为批标准化.

2. 定义

批标准化是对同一个指标下的数据进行处理的, 与其他指标无关.
将同一个项目下的数据用向量 x 表示:
x = ( x 1 , x 2 , x 2 , ⋯   , x k ) x = (x_1,x_2,x_2,\cdots,x_k) x=(x1,x2,x2,,xk)

均值 m m m 及方差 v v v 是标量 :
m = ∑ t = 1 k x t / n    v = ∑ t = 1 k ( x t − m ) 2 / n m=\sum_{t=1}^{k}x_{t}/n\\ \;\\ v =\sum_{t=1}^{k} (x_{t} - m)^2/n m=t=1kxt/nv=t=1k(xtm)2/n

为防止分母为零, 设一个极小数 ε \varepsilon ε, 如 ε = 1 0 − 5 \varepsilon=10^{-5} ε=105, 则数据标准化为 :
s i = x i − m v + ε s_{i} = \frac{x_{i} - m}{\sqrt{v + \varepsilon}} si=v+ε xim

为了增强数据的表征力, 添加一个线性变换, 得 :
y i = w ⋅ s i + b    y i    为    x i    经 过    B a t c h N o r m a l i z a t i o n    转 换 后 的 数 据    w    和    b    是 标 量 , 对 本 批 次 本 指 标 中 所 有 s i 是 相 同 的 y_i =w \cdot s_i + b\\ \;\\ y_i \;为\;x_i\;经过\;Batch Normalization\;转换后的数据\\ \;\\ w \;和\;b\;是标量, 对本批次本指标中所有 s_i 是相同的 yi=wsi+byixiBatchNormalizationwb,si

3. 训练过程中的反向传播的梯度

3.1 误差 e 对 x 的梯度

考虑一个 k 维输入向量 x , 经 Batch Normalization 得到向量 y, 往前 forward 传播得到误差值 error (标量 e ). 上游的误差梯度向量 ∇ e ( y ) \nabla e_{(y)} e(y) 已在反向传播时得到, 求 e 对 x 的梯度.

已知 :
e = f o r w a r d ( y )    ∇ e ( y ) = d e d y = ( ∂ e y 1 , ∂ e y 2 , ∂ e y 3 , ⋯   , ∂ e y k )    m = ∑ t = 1 k x t / k    v = ∑ t = 1 k ( x t − m ) 2 / k    s i = x i − m v + ε    y i = w ⋅ s i + b e=forward(y)\\ \;\\ \nabla e_{(y)}=\frac{de}{dy}=(\frac{\partial e}{y_1}, \frac{\partial e}{y_2}, \frac{\partial e}{y_3}, \cdots, \frac{\partial e}{y_k} )\\ \;\\ m=\sum_{t=1}^{k}x_{t}/k\\ \;\\ v =\sum_{t=1}^{k} (x_{t} - m)^2/k\\ \;\\ s_{i} = \frac{x_{i} - m}{\sqrt{v + \varepsilon}}\\ \;\\ y_i =w \cdot s_i + b\\ e=forward(y)e(y)=dyde=(y1e,y2e,y3e,,

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
Batch Normalization反向传播可以用链式法则来推导。下面是一个简单的推导过程: 假设输入为$x$,BN层的输出为$y$,其归一化后的值为$\hat{y}$,缩放和移位后的输出为$z$,BN层的参数为$\gamma$和$\beta$,损失函数为$L$。 则有: $$\hat{y}=\frac{x-\mu_B}{\sqrt{\sigma_B^2+\epsilon}}$$ $$y=\gamma\hat{y}+\beta$$ 其,$\mu_B$和$\sigma_B^2$分别是批量的均值和方差,$\epsilon$是一个很小的常数,防止分母为零。 BN层的反向传播分为两部分:对$\gamma$和$\beta$的梯度和对输入$x$的梯度。 首先,对$\gamma$和$\beta$的梯度可以直接计算: $$\frac{\partial L}{\partial \gamma}=\sum_{i=1}^n\frac{\partial L}{\partial y_i}\hat{y_i}$$ $$\frac{\partial L}{\partial \beta}=\sum_{i=1}^n\frac{\partial L}{\partial y_i}$$ 接下来,我们需要计算对输入$x$的梯度。我们可以先计算$\frac{\partial L}{\partial \hat{y}}$,然后通过链式法则计算出$\frac{\partial L}{\partial x}$: $$\frac{\partial L}{\partial \hat{y_i}}=\frac{\partial L}{\partial y_i}\gamma$$ $$\frac{\partial L}{\partial \sigma_B^2}=\sum_{i=1}^n\frac{\partial L}{\partial \hat{y_i}}(x_i-\mu_B)(-\frac{1}{2})(\sigma_B^2+\epsilon)^{-\frac{3}{2}}$$ $$\frac{\partial L}{\partial \mu_B}=-\frac{1}{\sqrt{\sigma_B^2+\epsilon}}\sum_{i=1}^n\frac{\partial L}{\partial \hat{y_i}}+\frac{\partial L}{\partial \sigma_B^2}\frac{1}{n}\sum_{i=1}^n-2(x_i-\mu_B)$$ $$\frac{\partial L}{\partial x_i}=\frac{\partial L}{\partial \hat{y_i}}\frac{1}{\sqrt{\sigma_B^2+\epsilon}}+\frac{\partial L}{\partial \sigma_B^2}\frac{2(x_i-\mu_B)}{n}+\frac{\partial L}{\partial \mu_B}\frac{1}{n}$$ 其,$\frac{\partial L}{\partial \mu_B}$和$\frac{\partial L}{\partial \sigma_B^2}$可以通过反向传播递归计算得到。 最后,我们可以使用$\frac{\partial L}{\partial x}$更新网络的权重参数。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值