############博主前言#######################
我写这篇文章的目的:
想必很多人听过神经网络中的momentum算法,
但是为啥叫momentum(动量)算法呢?
和物理中的动量有啥关系呢?
网上没有文章能说清楚。
所以这篇学术论文(哥大教授所撰写)的目的就是:
①弹簧振子的连续微分方程
②弹簧振子的近似离散方程
③从②中得到解集的递推公式(其实就是一个收敛的数列)
④把③中的递推公式写成伪代码,嵌入到神经网络中,用以神经网络的权重收敛。
上述就是神经网络的Momentum算法的由来。
###############下面开始正题########################
《On the Momentum Term in Gradient Descent Learning Algorithms》这篇文章的意图是分析神经网络中的optimizer是momentum的时候,为什么会加速权重的收敛,依据是作者原文开头写着:
Although it is well known that such a term greatly improves the speed of learning,
there have been few rigorous studies of its mechanisms.
In this paper, I show that in the limit of continuous time,
the momentum parameter is analogous to the mass of Newtonian particles that move through a viscous medium in a conservative force field
网上99%的文章讲述momentum的都是虚晃一枪,其实需要整整一篇期刊来对背后的数学原理进行详细的论述。
这篇论文通过“弹簧谐振子”来使用数学手段建模,来严格讨论为什么Momentum会加速收敛,以及满足哪些约束条件下会加速。
文章分为两个部分:
1.连续型,建立微分方程讨论Momentum
2.离散型,建立差分方程讨论Momentum(神经网络的momentum对应这一种)
但是作者为了方便大家理解,还是先从第一种,也就是连续型开始讨论Momentum的加速收敛的效果。
整篇文章提到神经网络只有两次,一次是开头,一次是文章结尾的参考文献里面。
我一开始也看晕了,这篇文章到底和神经网络啥关系,不会是作者年纪太大写跑题了吧?
后来我想了一两天才回过神来。
首先注意这篇论文[1]有两种排版,这篇博客记载的是老的排版。
这篇论文的作者是钱宁,哥伦毕业大学的教授,但是已经联系不上了。
这篇博客大概也是百度上唯一一篇从微分方程、差分方程来解释神经网络的momentum优化原理的文章。
#--------------下面开始:Section1:Introduction--------------
无momentum版本 | Momentum版本 |
---|---|
Δ w t = − ϵ ∇ w E ( w t ) ( 1 ) \Delta \mathbf{w}_{t}=-\epsilon \nabla_{\mathrm{w}} E\left(\mathbf{w}_{t}\right)(1) Δwt=−ϵ∇wE(wt)(1) | Δ w t = − ϵ ∇ w E ( w ) + p Δ w t − 1 ( 2 ) \Delta \mathbf{w}_{t}=-\epsilon \nabla_{\mathrm{w}} E(\mathbf{w})+p \Delta \mathbf{w}_{t-1}(2) Δwt=−ϵ∇wE(w)+pΔwt−1(2) |
文中有一句话:
p is the momentum parameter.
-------------------------下面是Section2:A physical Analogy----------------------------------------
首先是两个微分方程进行比较:
无momentum版本 | Momentum版本 |
---|---|
d w d t = − ϵ ▽ w E ( w ) ( 3 ) \frac{d \text{w}}{dt}=-\epsilon\triangledown_{\text{w}}E(\text{w}) (3) dtdw=−ϵ▽wE(w)(3) | m d 2 w d t 2 + μ d w d t = − ▽ w E ( w ) ( 4 ) m\frac{d^2 \text{w}}{dt^2}+μ\frac{d\text{w}}{dt}=-\triangledown_{\text{w}}E(\text{w})(4) mdt2d2w+μdtdw=−▽wE(w)(4) |
注意:表格中的两个微积分方程都是"力平衡方程".(结合牛顿第三定律F=ma来理解)
然后把微积分的momentum版本转化成差分方程(其实就是数列)版本:
m
w
t
+
Δ
t
+
w
t
−
Δ
t
−
2
w
t
Δ
t
2
+
μ
w
t
+
Δ
t
−
w
t
Δ
t
=
−
∇
w
E
(
w
)
(
5
)
m \frac{\mathbf{w}_{t+\Delta t}+\mathbf{w}_{t-\Delta t}-2 \mathbf{w}_{t}}{\Delta t^{2}}+\mu \frac{\mathbf{w}_{t+\Delta t}-\mathbf{w}_{t}}{\Delta t}=-\nabla_{\mathbf{w}} E(\mathbf{w})(5)
mΔt2wt+Δt+wt−Δt−2wt+μΔtwt+Δt−wt=−∇wE(w)(5)
整理得到:
w
t
+
Δ
t
−
w
t
=
−
(
Δ
t
)
2
m
+
μ
Δ
t
∇
w
E
(
w
)
+
m
m
+
μ
Δ
t
(
w
t
−
w
t
−
Δ
t
)
(
6
)
\mathbf{w}_{t+\Delta t}-\mathbf{w}_{t}=-\frac{(\Delta t)^{2}}{m+\mu \Delta t} \nabla_{\mathrm{w}} E(\mathbf{w})+\frac{m}{m+\mu \Delta t}\left(\mathbf{w}_{t}-\mathbf{w}_{t-\Delta t}\right)(6)
wt+Δt−wt=−m+μΔt(Δt)2∇wE(w)+m+μΔtm(wt−wt−Δt)(6)
对照Section1中的Momentum版本,得到:
ϵ
=
(
Δ
t
)
2
m
+
μ
Δ
t
(
7
)
\epsilon=\frac{(\Delta t)^{2}}{m+\mu \Delta t}(7)
ϵ=m+μΔt(Δt)2(7)
p
=
m
m
+
μ
Δ
t
(
8
)
p=\frac{m}{m+\mu \Delta t}(8)
p=m+μΔtm(8)
其中:
m
:
质
点
质
量
m:质点质量
m:质点质量
μ
:
阻
尼
系
数
\mu:阻尼系数
μ:阻尼系数
所以所谓的momentum项,其实就是指的:
m
m
+
μ
Δ
t
(
w
t
−
w
t
−
Δ
t
)
\frac{m}{m+\mu \Delta t}\left(\mathbf{w}_{t}-\mathbf{w}_{t-\Delta t}\right)
m+μΔtm(wt−wt−Δt)
也就是section1中提到的
p
Δ
w
t
−
1
p \Delta \mathbf{w}_{t-1}
pΔwt−1
####################################
这里稍微提一下,在我们的中学课本中,摩擦力的公式是:
m
g
μ
mg\mu
mgμ
但是上面表格里面的微分方程里的摩擦力表达式却是:
μ
d
w
d
t
\mu\frac{dw}{dt}
μdtdw
这是为什么呢?
因为这篇论文里的场景是粘性介质,当小球运动速度为0,那么与粘性介质的摩擦力就是0.
准确地来说:这不叫摩擦力,因为没有“摩擦”的特征,而是一种粘力。
----------------------------------------------------------------------------------
Section 3 . Stability and Covergence Analyses
##########3.1 Continuous Time Case##########
这一节开头先说了系统的总能量是:
E
T
=
1
2
m
d
w
T
d
t
d
w
d
t
+
E
(
w
)
(
9
)
E_{T}=\frac{1}{2} m \frac{\mathrm{d} \mathbf{w}^{T}}{\mathrm{d} t} \frac{\mathrm{d} \mathbf{w}}{\mathrm{d} t}+E(\mathbf{w})(9)
ET=21mdtdwTdtdw+E(w)(9)
这个式子的意思很明显:
等号右侧的第一项是在说绑在弹簧上面的质点动能,
等号右侧的第二项是在说弹簧的弹性势能。
这个式子里面为什么有
w
T
\text{w}^T
wT呢?因为这里需要得到的是一个值,而
w
\text{w}
w和
w
T
\text{w}^T
wT都是矢量,所以为了得到一个值(标量),而不是一个(矢量),这里需要的是
dw
T
d
t
⋅
dw
d
t
\frac{\text{dw}^T}{dt}·\frac{\text{dw}}{dt}
dtdwT⋅dtdw,而不是
dw
d
t
⋅
dw
d
t
\frac{\text{dw}}{dt}·\frac{\text{dw}}{dt}
dtdw⋅dtdw
从这里也可以推断出,这篇论文中的的 w \text{w} w是一个列向量
那么momentum项如何加速收敛呢:
m
d
2
w
d
t
2
+
μ
d
w
d
t
≈
−
H
(
w
−
w
0
)
(
10
)
m \frac{\mathrm{d}^{2} \mathbf{w}}{\mathrm{d} t^{2}}+\mu \frac{\mathrm{d} \mathbf{w}}{\mathrm{d} t} \approx-H\left(\mathbf{w}-\mathbf{w}_{0}\right)(10)
mdt2d2w+μdtdw≈−H(w−w0)(10)
这里为什么有
w
0
\mathbf{w}_{0}
w0呢?
对于弹簧而言这个是弹簧的初始位置,如果弹簧一开始就在恒定点,弹簧就不需要晃来晃去了。
对于神经网络而言,这个是权重的初始值,如果没有初始值,神经网络的激活函数可能一开始就陷入了饱和区,导致难以训练。
Hessian矩阵里面其实就是一大堆的二阶导数,如下:
h
i
,
j
=
∂
2
E
(
w
)
∂
w
i
∂
w
j
∣
w
0
(
11
)
h_{i, j}=\left.\frac{\partial^{2} E(\mathbf{w})}{\partial w_{i} \partial w_{j}}\right|_{\mathbf{w}_{0}}(11)
hi,j=∂wi∂wj∂2E(w)∣∣∣∣w0(11)
作者随后对这个矩阵进行了相似对角化分解:
H
=
Q
K
Q
T
,
Q
Q
T
=
I
(
12
)
H=QKQ^T,QQ^T=I(12)
H=QKQT,QQT=I(12)
K
=
(
k
1
k
2
⋱
k
n
)
(
k
i
>
0
)
(
13
)
K=\left(\begin{array}{cccc}{k_{1}} & {} & {} & {} \\ {} & {k_{2}} & {} & {} \\ {} & {} & {\ddots} & {} \\ {} & {} & {} & {k_{n}}\end{array}\right)\left(k_{i}>0\right)(13)
K=⎝⎜⎜⎛k1k2⋱kn⎠⎟⎟⎞(ki>0)(13)
然后令:
w
′
=
Q
T
w
(
14
)
\text{w}'=Q^T\text{w}(14)
w′=QTw(14)
这个式子的计算结果
w
′
\text{w}'
w′肯定是一个列向量。
可以得到一个阻尼振子的表达式:
m
d
2
w
′
d
t
2
+
μ
d
w
′
d
t
=
−
K
w
′
(
15
)
m \frac{\mathrm{d}^{2} \mathbf{w}^{\prime}}{\mathrm{d} t^{2}}+\mu \frac{\mathrm{d} \mathbf{w}^{\prime}}{\mathrm{d} t}=-K \mathbf{w}^{\prime}(15)
mdt2d2w′+μdtdw′=−Kw′(15)
个人认为,从(10)、(11)这2个式子都是错误的,理由是:
E
(w)
E\text{(w)}
E(w)代表的是弹簧的弹性势能,这个等式是力平衡方程,所以等号右侧应该是满足胡克定律的kx形式,也就是说,对
E
(w)
E\text{(w)}
E(w)求导一次就可以达到目标,而不是两次。
补充:式(26)可以证明这两个式子存在错误.
上面的式(12)~(15)是为了使用对角化来进行解耦,需要的线代知识参考[2]
根据[15]可以得到关于每个弹簧振子的微分方程是:
m
d
2
w
i
′
d
t
2
+
μ
d
w
i
′
d
t
=
−
k
i
w
i
′
(
16
)
m \frac{d^{2} w_{i}^{\prime}}{d t^{2}}+\mu \frac{d w_{i}^{\prime}}{d t}=-k_{i} w_{i}^{\prime}(16)
mdt2d2wi′+μdtdwi′=−kiwi′(16)
在没有momentum term的情况下,
也就是当m=0时,式(16)的解集是:
w
i
′
(
t
)
=
c
e
λ
i
,
0
t
(
17
)
w_{i}^{\prime}(t)=c e^{\lambda_{i, 0} t}(17)
wi′(t)=ceλi,0t(17)
其中c是常数,且
λ
i
,
0
=
−
k
i
μ
(
18
)
\lambda_{i, 0}=-\frac{k_{i}}{\mu}(18)
λi,0=−μki(18)
如果m≠0时,解集则为:
w
i
′
(
t
)
=
c
1
e
λ
i
,
1
t
+
c
2
e
λ
i
,
2
t
(
19
)
w_{i}^{\prime}(t)=c_{1} e^{\lambda_{i, 1} t}+c_{2} e^{\lambda_{i, 2} t}(19)
wi′(t)=c1eλi,1t+c2eλi,2t(19)
其中,
c
1
c_1
c1和
c
2
c_2
c2是常数,并且:
λ
i
,
1
=
−
μ
2
m
+
μ
m
(
μ
4
m
−
k
i
μ
)
(
20
)
\lambda_{i,1}=-\frac{\mu}{2 m} + \sqrt{\frac{\mu}{m}\left(\frac{\mu}{4 m}-\frac{k_{i}}{\mu}\right)}(20)
λi,1=−2mμ+mμ(4mμ−μki)(20)
λ i , 2 = − μ 2 m − μ m ( μ 4 m − k i μ ) ( 20 ) \lambda_{i,2}=-\frac{\mu}{2 m} - \sqrt{\frac{\mu}{m}\left(\frac{\mu}{4 m}-\frac{k_{i}}{\mu}\right)}(20) λi,2=−2mμ−mμ(4mμ−μki)(20)
容易证明:
∣
Re
λ
i
,
1
∣
≤
∣
Re
λ
i
,
2
∣
\left|\operatorname{Re} \lambda_{i, 1}\right| \leq\left|\operatorname{Re} \lambda_{i, 2}\right|
∣Reλi,1∣≤∣Reλi,2∣
这里我们先做下小小的整理:
参数 | 具体的表达式 | 对应的算法 |
---|---|---|
λ i , 0 \lambda_{i,0} λi,0 | λ i , 0 = − k i μ \lambda_{i, 0}=-\frac{k_{i}}{\mu} λi,0=−μki | SGD |
λ i , 1 \lambda_{i,1} λi,1 | λ i , 1 = − μ 2 m + μ m ( μ 4 m − k i μ ) \lambda_{i,1}=-\frac{\mu}{2 m} + \sqrt{\frac{\mu}{m}\left(\frac{\mu}{4 m}-\frac{k_{i}}{\mu}\right)} λi,1=−2mμ+mμ(4mμ−μki) | Momentum算法 |
λ i , 2 \lambda_{i,2} λi,2 | λ i , 2 = − μ 2 m − μ m ( μ 4 m − k i μ ) \lambda_{i,2}=-\frac{\mu}{2 m} - \sqrt{\frac{\mu}{m}\left(\frac{\mu}{4 m}-\frac{k_{i}}{\mu}\right)} λi,2=−2mμ−mμ(4mμ−μki) | Momentum算法 |
看到这里你一定就会想,其实所谓的Momentum不一定速度真的就比SGD来的快.
要想收敛是有条件的,
R
e
s
u
l
t
1
Result\ 1
Result 1如下:
∣ R e λ i , 1 ∣ > ∣ R e λ i , 0 ∣ ( 22 ) |Re\lambda_{i,1}|>|Re \lambda_{i,0}|(22) ∣Reλi,1∣>∣Reλi,0∣(22)
⇒ μ 2 m − μ m ( μ 4 m − k i μ ) > k i μ ( 22 ) ⇒\frac{\mu}{2 m} - \sqrt{\frac{\mu}{m}\left(\frac{\mu}{4 m}-\frac{k_{i}}{\mu}\right)}>\frac{k_{i}}{\mu}(22) ⇒2mμ−mμ(4mμ−μki)>μki(22)
⇒
k
<
μ
2
2
m
(
23
)
⇒k<\frac{\mu^2}{2m}(23)
⇒k<2mμ2(23)
这个式(23)其实是作者忽略了式(22)中的根号中的项得到的,所以其实并不严谨,
下面是我自己的推导:
μ
2
m
−
k
i
μ
>
μ
m
(
μ
4
m
−
k
i
μ
)
\frac{\mu}{2m}-\frac{k_i}{\mu}>\sqrt{\frac{\mu}{m}\left(\frac{\mu}{4 m}-\frac{k_{i}}{\mu}\right)}
2mμ−μki>mμ(4mμ−μki)
⇔ ( μ 2 m − k i μ ) 2 > μ m ( μ 4 m − k i μ ) ⇔(\frac{\mu}{2m}-\frac{k_i}{\mu})^2>\frac{\mu}{m}\left(\frac{\mu}{4 m}-\frac{k_{i}}{\mu}\right) ⇔(2mμ−μki)2>mμ(4mμ−μki)
⇔
(
μ
2
m
)
2
+
(
k
i
μ
)
2
−
k
i
m
>
μ
2
4
m
2
−
k
i
m
⇔(\frac{\mu}{2m})^2+(\frac{k_i}{\mu})^2-\frac{k_i}{m}>\frac{\mu ^2}{4m^2}-\frac{k_i}{m}
⇔(2mμ)2+(μki)2−mki>4m2μ2−mki
⇔
(
k
i
μ
)
2
>
0
⇔(\frac{k_i}{\mu})^2>0
⇔(μki)2>0
我们看到是恒成立的,但是注意,虽然Momentum微分方程的解集的任意一项都收敛得比SGD来的快,但是由于常数项
c
1
c_1
c1和
c
2
c_2
c2是不确定的,所以无法绝对地判定Momentum的速度比SGD快.
因为权重的初始值是随机的,而这个随机值会决定
c
1
c_1
c1和
c
2
c_2
c2的值,而这两个值会影响总共的收敛时间.
上述证明,只能用来表示,收敛中的某个"时间微元",momentum比SGD来的快.
另外注意,Momentum微分方程的解集中,两个项是一起衰减的.
R
e
s
u
l
t
2
Result\ 2
Result 2:
令
α
=
∣
R
e
λ
i
,
1
∣
∣
R
e
λ
i
,
0
∣
(
24
)
令\alpha=\frac{|Re\lambda_{i,1}|}{|Re\lambda_{i,0}|}(24)
令α=∣Reλi,0∣∣Reλi,1∣(24)
这里为什么搞这么个东西呢?
因为是为了比较SGD和Momentum哪个衰减得更快,
由于两者的微分方程的解集(17)与(20)中,
e
e
e的指数都是负数,显然哪个负数的绝对值更大,哪个方案就衰减得更快,所以作者这里想研究的是当
α
为
何
值
时
\alpha为何值时
α为何值时时,可以让Momentum达到最快的收敛速度.
但是由于作者的
R
e
s
u
l
t
1
Result\ 1
Result 1中的证明错误,所以
R
e
s
u
l
t
2
Result\ 2
Result 2也不再具有可信度.
##########3.2 Continuous Time Case##########
先整理下式(6):
w
t
+
Δ
t
−
w
t
=
−
ϵ
∇
w
E
(
w
)
+
p
(
w
t
−
w
t
−
Δ
t
)
\mathbf{w}_{t+\Delta t}-\mathbf{w}_{t}=-\epsilon\nabla_{\mathrm{w}} E(\mathbf{w})+p\left(\mathbf{w}_{t}-\mathbf{w}_{t-\Delta t}\right)
wt+Δt−wt=−ϵ∇wE(w)+p(wt−wt−Δt)
⇒
w
t
+
Δ
t
=
(
1
+
p
)
I
w
t
−
ϵ
∇
w
E
(
w
)
−
p
w
t
−
Δ
t
⇒\mathbf{w}_{t+\Delta t}=(1+p)I\text{w}_t-\epsilon\nabla_{\mathrm{w}} E(\mathbf{w})-p\mathbf{w}_{t-\Delta t}
⇒wt+Δt=(1+p)Iwt−ϵ∇wE(w)−pwt−Δt
然后下面三个式子:
w
t
+
1
=
[
(
1
+
p
)
I
−
ϵ
H
]
w
t
−
p
w
t
−
1
(
26
)
\mathbf{w}_{t+1}=[(1+p) I-\epsilon H] \mathbf{w}_{t}-p \mathbf{w}_{t-1}(26)
wt+1=[(1+p)I−ϵH]wt−pwt−1(26)
w
t
+
1
′
=
[
(
1
+
p
)
I
−
ϵ
H
]
w
t
′
−
p
w
t
−
1
′
(
27
)
\mathbf{w}_{t+1}^{\prime}=[(1+p) I-\epsilon H] \mathbf{w}_{t}^{\prime}-p \mathbf{w}_{t-1}^{\prime}(27)
wt+1′=[(1+p)I−ϵH]wt′−pwt−1′(27)
w
i
,
t
+
1
′
=
[
1
+
p
−
ϵ
k
i
]
w
i
,
t
′
−
p
w
i
,
t
−
1
′
(
28
)
w_{i, t+1}^{\prime}=\left[1+p-\epsilon k_{i}\right] w_{i, t}^{\prime}-p w_{i, t-1}^{\prime}(28)
wi,t+1′=[1+p−ϵki]wi,t′−pwi,t−1′(28)
(
w
i
,
t
′
w
i
,
t
+
1
′
)
=
A
(
w
i
,
t
−
1
′
w
i
,
t
′
)
=
A
t
(
w
i
,
0
′
w
i
,
1
′
)
(
29
)
\left(\begin{array}{c}{w_{i, t}^{\prime}} \\ {w_{i, t+1}^{\prime}}\end{array}\right)=A\left(\begin{array}{c}{w_{i, t-1}^{\prime}} \\ {w_{i, t}^{\prime}}\end{array}\right)=A^{t}\left(\begin{array}{c}{w_{i, 0}^{\prime}} \\ {w_{i, 1}^{\prime}}\end{array}\right)(29)
(wi,t′wi,t+1′)=A(wi,t−1′wi,t′)=At(wi,0′wi,1′)(29)
A
=
(
0
1
−
p
1
+
p
−
ϵ
k
i
)
(
30
)
A=\left(\begin{array}{cc}{0} & {1} \\ {-p} & {1+p-\epsilon k_{i}}\end{array}\right)(30)
A=(0−p11+p−ϵki)(30)
λ
i
,
{
1
,
2
}
=
1
+
p
−
ϵ
k
i
±
(
1
+
p
−
ϵ
k
i
)
2
−
4
p
2
(
31
)
\lambda_{i,\{1,2\}}=\frac{1+p-\epsilon k_{i} \pm \sqrt{\left(1+p-\epsilon k_{i}\right)^{2}-4 p}}{2}(31)
λi,{1,2}=21+p−ϵki±(1+p−ϵki)2−4p(31)
这里稍微说下,为什么要研究矩阵A呢,这个就好像高中的等比数列一样,增益需要小于1,才能确保最终结果会收敛.
论文中没有给出差分方程的解集,自己把解集写在这里:
w
t
=
c
1
λ
1
t
+
c
2
λ
2
t
w_t=c_1\lambda_1^t+c_2\lambda_2^t
wt=c1λ1t+c2λ2t
然后就能理解论文中为啥出来了下面这个式子:
Max
(
∣
λ
i
,
1
∣
,
∣
λ
i
,
2
∣
)
<
1
(
32
)
\operatorname{Max}\left(\left|\lambda_{i, 1}\right|,\left|\lambda_{i, 2}\right|\right)<1(32)
Max(∣λi,1∣,∣λi,2∣)<1(32)
然后是:
λ
i
,
0
=
1
−
ϵ
k
i
(
33
)
\lambda_{i, 0}=1-\epsilon k_{i}(33)
λi,0=1−ϵki(33)
λ
i
,
{
1
,
2
}
≈
{
1
−
ϵ
k
i
1
−
p
p
(
1
+
ϵ
k
i
1
−
p
)
\lambda_{i,\{1,2\}} \approx\left\{\begin{array}{l}{1-\frac{\epsilon k_{i}}{1-p}} \\ {p\left(1+\frac{\epsilon k_{i}}{1-p}\right)}\end{array}\right.
λi,{1,2}≈{1−1−pϵkip(1+1−pϵki)
论文解读到这里为止,因为为了保证(32),得到的(33)和(34)是错误的,理由是(34)中的
λ
\lambda
λ的分母中存在(1-p)显然属于计算错误.
################################论文解读结束#####################
概念对照:
振荡器(可以是电路或者弹簧振子) | 神经网络 |
---|---|
振荡器能量 | 误差函数MSE |
振荡器微分方程中的 L C d 2 u c d t 2 LC\frac{d^2u_c}{dt^2} LCdt2d2uc | 伪代码中的Momentum项 |
能量通过(电阻上的或者摩擦力上的)热能耗散 | 通过不断减小MSE实现拟合、收敛 |
|R很大时,过阻尼,震荡急速衰减(视觉上等于无法起振),稳定|加速神经网络收敛|
|振荡器平衡点|权重w的稳定值|
################差分方程如何转化为Momentum算法#####################
下面来看下,弹簧振子建模后的差分方程如何转化为神经挽留过中的Momentum算法:
首先附上[2]中的Momentum的算法:
x
k
+
1
=
x
k
−
η
∇
f
(
x
k
)
+
ρ
(
x
k
−
x
k
−
1
)
x_{k+1}=x_{k}-\eta \nabla f\left(x_{k}\right)+\rho\left(x_{k}-x_{k-1}\right)
xk+1=xk−η∇f(xk)+ρ(xk−xk−1)
当
ρ
\rho
ρ=0时,上面的式子就是SGD算法
当
ρ
\rho
ρ≠0时,上面的式子就是Momentum算法
然后回顾下式(28)
w
i
,
t
+
1
′
=
[
1
+
p
−
ϵ
k
i
]
w
i
,
t
′
−
p
w
i
,
t
−
1
′
(
28
)
w_{i, t+1}^{\prime}=\left[1+p-\epsilon k_{i}\right] w_{i, t}^{\prime}-p w_{i, t-1}^{\prime}(28)
wi,t+1′=[1+p−ϵki]wi,t′−pwi,t−1′(28)
我们整理下得到 :
w
i
,
t
+
1
′
=
w
i
,
t
′
−
ϵ
k
i
w
i
,
t
′
+
p
(
w
i
,
t
′
−
w
i
,
t
−
1
′
)
w_{i, t+1}^{\prime}=w_{i, t}^{\prime}-\epsilon k_iw^{\prime}_{i,t}+p(w^{\prime}_{i,t}-w^{\prime}_{i,t-1})
wi,t+1′=wi,t′−ϵkiwi,t′+p(wi,t′−wi,t−1′)
当p=0时,上面的式子就是SGD算法
当p≠0时,上面的式子就是Momentum算法
其中的Momentum项就是
p
(
w
i
,
t
′
−
w
i
,
t
−
1
′
)
p(w^{\prime}_{i,t}-w^{\prime}_{i,t-1})
p(wi,t′−wi,t−1′)
另外:
根据[2]的说法,这里的momentum被称为:
Polyak’s momentum
另外,同样阐述神经网络背后的物理机制的文章还有[3][4]
###########################################################################
为什么上面弹簧的差分方程得到的Momentum算法与Andrew的笔记上是不一致的?
上述写法根据[8]
Andrew给出的Momentum算法与tensoflow的momentum.py中的注释给出的公式是一致的:
注意:这个代码的实现是没有问题的,但是注释有问题。但是代码中的gradient其实是动量相关的项。
代码中的momentum其实只是加权平均系数
β
\beta
β。
强调下,momentum相关的算法中不存在“严格的momentum项”,“严格的momentum项”仅仅存在于微分方程近似测差分方程中的二次差分项。
这是第二种momentum形式,这种形式最早来源于[5]
列出[5]中的几个式子如下:
Δ
x
k
=
γ
Δ
x
k
−
1
−
(
1
−
γ
)
α
g
k
(
3
)
\Delta \mathbf{x}_{k}=\gamma \Delta \mathbf{x}_{k-1}-(1-\gamma) \alpha \mathbf{g}_{k}(3)
Δxk=γΔxk−1−(1−γ)αgk(3)
∇
F
(
x
)
=
H
x
+
d
(
4
)
\nabla F(\mathbf{x})=\mathbf{H} \mathbf{x}+\mathbf{d}(4)
∇F(x)=Hx+d(4)
x
k
+
1
−
x
k
=
γ
(
x
k
−
x
k
−
1
)
−
(
1
−
γ
)
α
(
H
x
k
+
d
)
(
5
)
\mathbf{x}_{k+1}-\mathbf{x}_{k}=\gamma\left(\mathbf{x}_{k}-\mathbf{x}_{k-1}\right)-(1-\gamma) \alpha\left(\mathbf{H} \mathbf{x}_{k}+\mathbf{d}\right)(5)
xk+1−xk=γ(xk−xk−1)−(1−γ)α(Hxk+d)(5)
对比上面分析过的[6]中的式(28):
w
i
,
t
+
1
′
−
w
i
,
t
′
=
p
(
w
i
,
t
′
−
w
i
,
t
−
1
′
)
−
ϵ
k
i
w
i
,
t
′
(
28
)
w_{i, t+1}^{\prime}-w_{i, t}^{\prime}=p(w^{\prime}_{i,t}-w^{\prime}_{i,t-1})-\epsilon k_iw^{\prime}_{i,t}(28)
wi,t+1′−wi,t′=p(wi,t′−wi,t−1′)−ϵkiwi,t′(28)
可能你们会有疑问,[6]的式(28)右侧的两个系数是
p
和
ϵ
p和\epsilon
p和ϵ,看起来是相互独立的.
[5]的式(5)等式右侧的两个系数是
γ
和
1
−
γ
\gamma和1-\gamma
γ和1−γ,怎么看上去好像这两个系数是有关系的呢?
答:
可以参考[6]的下面两个式子:
ϵ
=
(
Δ
t
)
2
m
+
μ
Δ
t
(
7
)
\epsilon=\frac{(\Delta t)^{2}}{m+\mu \Delta t}(7)
ϵ=m+μΔt(Δt)2(7)
p
=
m
m
+
μ
Δ
t
(
8
)
p=\frac{m}{m+\mu \Delta t}(8)
p=m+μΔtm(8)
可以发现
p
和
ϵ
p和\epsilon
p和ϵ也是有关系的,并不是相互独立的.
好了,但是和吴恩达的笔记以及tensorflow上面的伪代码还是不一致啊,怎么回事?
然后[7]来了,把等式右侧的两个项系数整合到一起了,相关式子如下:
x
k
+
1
=
x
k
+
α
k
[
r
k
+
γ
k
(
x
k
−
x
k
−
1
)
]
(
4
)
\mathbf{x}_{k+1}=\mathbf{x}_{k}+\alpha_{k}\left[\mathbf{r}_{k}+\gamma_{k}\left(\mathbf{x}_{k}-\mathbf{x}_{k-1}\right)\right](4)
xk+1=xk+αk[rk+γk(xk−xk−1)](4)
其中:
∇
f
(
x
)
=
A
x
−
b
=
:
−
r
\nabla f(\mathbf{x})=\mathbf{A} \mathbf{x}-\mathbf{b}=:-\mathbf{r}
∇f(x)=Ax−b=:−r
此时[7]的式(4)已经与tensorflow中注释中的伪代码完全一致了.
那么[7]的式(4)与上面[6]的式(28)是否一致呢?
[6]的式(28)是做差后递归的形式
[7]的式(4)是递归的形式
我们处理[7]的式(4):
x
k
+
1
−
x
k
=
α
k
[
r
k
+
γ
k
(
x
k
−
x
k
−
1
)
]
(
4
)
\mathbf{x}_{k+1}-\mathbf{x}_{k}=\alpha_{k}\left[\mathbf{r}_{k}+\gamma_{k}\left(\mathbf{x}_{k}-\mathbf{x}_{k-1}\right)\right](4)
xk+1−xk=αk[rk+γk(xk−xk−1)](4)
对比[6]的式(28):
w
i
,
t
+
1
′
−
w
i
,
t
′
=
−
ϵ
k
i
w
i
,
t
′
+
p
(
w
i
,
t
′
−
w
i
,
t
−
1
′
)
(
28
)
w_{i, t+1}^{\prime}-w_{i, t}^{\prime}=-\epsilon k_iw^{\prime}_{i,t}+p(w^{\prime}_{i,t}-w^{\prime}_{i,t-1})(28)
wi,t+1′−wi,t′=−ϵkiwi,t′+p(wi,t′−wi,t−1′)(28)
前面已经讲过:
p
和
ϵ
p和\epsilon
p和ϵ并不是两个互相独立的变量,所以[7]的式(4)提取公因子
α
k
\alpha_k
αk的做法是合理的.
##########################大总结############################
[6]阐述了如何从一个弹簧振子的微分方程逐步推导出适用于神经网络的momentum算法.
momentum算法的写法演化进程:
[6]的式(28)->[5]的式(5)->[7]的式(4)->tensorflow中momentum伪代码形式/andrew讲课笔记中的形式
其中[6]的式(28)是最为早期的形式,能直接用来配合弹簧振子理解物理动量.
##############补充关于momentum指代的具体含义##################
现在momentum学界有两种说法
1.来自弹簧的动量,也就是本文解读的[6]中的差分方程中的二次项,同时也是物理中的动量项mv
2.把指数加权平均系数认为是momentum,参考[9]
[9]也是tensorflow中Momentum实现的依据
其实个人认为应该[6]的说法更加靠谱些
Reference:
[1]RLC串联电路的微分方程
[2]常见的关于momentum的误解(上)
[3]Su et al. A Differential Equation for Modeling Nesterov’s Accelerated Gradient Method: Theory and Insights. 2015
[4]Yang et al. The Physical Systems Behind Optimization Algorithms. 2016
[5]Stability of Steepest Descent With Momentum for Quadratic Functions
[6]On the Momentum Term in Gradient Descen Learning Algorithm
[7]Steepest descent with momentum for quadratic functions is a version of the conjugate gradient method
[8]Gradient Descent with Momentum
[9]On the importance of initialization and momentum in deep learning