1 前言
感谢网友“天泽28”的帮助,
原文链接如下:
https://blog.csdn.net/u012328159/article/details/80311892
2 Nesterov方法的公式
2.1 夏侯南溪采用的Nesterov公式
这里我们最终选择的是根据 Neural Network Libraries文档中公式经过扩展,延伸得到的公式;
我觉得这样的公式是最容易理解的,也是最符合我们的认知的,相应的公式为:
v
t
+
1
=
γ
v
t
−
η
∗
∇
w
t
J
(
w
t
)
w
t
+
1
=
w
t
−
γ
v
t
+
(
1
+
γ
)
∗
v
t
+
1
v_{t+1} = \gamma v_t - \eta*\nabla_{w_t}J(w_t)\\ w_{t+1} = w_t - \gamma v_t + (1 + \gamma)*v_{t+1}
vt+1=γvt−η∗∇wtJ(wt)wt+1=wt−γvt+(1+γ)∗vt+1
我们将梯度下降的过程想象为小球从斜坡上下降的过程,这里的
v
t
v_t
vt指当前时刻t小球的速度,
∇
w
t
J
(
w
t
)
\nabla_{w_t}J(w_t)
∇wtJ(wt)为当前的梯度值、也是小球的外力(可以理解为是重力加速度在当前斜坡方向上的分力),
w
t
w_t
wt为小球在t时刻小球的参数超平面坐标,
η
\eta
η为学习率;
2.2 Sutskever论文中的公式
原始公式出自于 Ilya Sutskever和 Geoffrey Hinton等人的文章《On the importance of initialization and momentum in deep learning》,
(这里我参阅了网友“天泽28”的博文《深度学习中优化方法——momentum、Nesterov Momentum、AdaGrad、Adadelta、RMSprop、Adam》,里面对Nesterov方法及相关资料分析的十分清晰);
而由于这个公式难以实现,可以看到导数内部还存在一个因式“
θ
t
+
μ
v
t
\theta_t + \mu v_t
θt+μvt”,所以在实现时,一般会采用其它的公式;
2.3 PyTorch文档中使用的公式
(由于Nesterov方法有多种不同的实现方法,下面列出的是PyTorch的官方文档中给出的公式,这个公式是我根据PyTorch中原始公式进行相应扩展得出的)
其公式如下:
v
t
+
1
=
μ
∗
v
t
+
α
∗
g
t
p
t
+
1
=
p
t
−
l
r
∗
(
v
t
+
1
+
α
∗
g
t
)
v_{t+1} = \mu*v_{t} + \alpha*g_{t}\\ p_{t+1} = p_t - lr*(v_{t+1} + \alpha*g_{t})
vt+1=μ∗vt+α∗gtpt+1=pt−lr∗(vt+1+α∗gt)
仔细一看,我发现这个跟PyTorch中提供的公式还是有些区别,这里我们来进行一下推导和解释,
首先,需要说明的是:
PyTorch中的
v
P
v^P
vP累积的是正梯度,所以最后在更新参数时是相减,用“
−
-
−号 ”;
Sutskever等人的论文中
v
S
v^S
vS累积的是负梯度,所以最后在更新参数时是相加,用“
+
+
+号 ”;
所以有
v
P
=
−
v
S
v^P=-v^S
vP=−vS
3 Nesterov南溪选择的公式与论文原始公式等价性的证明
对于Nesterov原始公式与其等价公式的证明,之前我一直没有看懂,不太懂两个公式的条件和结论到底分别是什么,
赵老师和其他同学在证明的时候,各种不同的带入推导看似很有道理,
在我看来,却没有什么逻辑性:
一个很重要的原因是,他们在推导的时候并没有明确证明的条件和结论,所以一直很难看懂;还有一点就是,他们选择的公式有点复杂,不容易理解,所以我们这里就不进行赘述了;
这里我们对南溪选择的公式和原始论文公式的等价性进行证明,
已知条件:
我们已知Nesterov算法的两种递推公式,分别为:
Sutskever原始论文的公式(“公式S”),
v
t
+
1
=
μ
∗
v
t
−
ε
∇
f
(
θ
t
+
μ
∗
v
t
)
θ
t
+
1
=
θ
t
+
v
t
+
1
v_{t+1} = \mu*v_{t} - \varepsilon\nabla f\left ( \theta_t+ \mu*v_{t}\right )\\ \theta_{t+1} = \theta_t + v_{t+1}
vt+1=μ∗vt−ε∇f(θt+μ∗vt)θt+1=θt+vt+1
南溪选择的公式(“公式N”),
v
t
+
1
=
γ
v
t
−
η
∗
∇
w
t
J
(
w
t
)
w
t
+
1
=
w
t
−
γ
v
t
+
(
1
+
γ
)
∗
v
t
+
1
v_{t+1} = \gamma v_t - \eta*\nabla_{w_t}J(w_t)\\ w_{t+1} =w_t - \gamma v_t + (1 + \gamma)*v_{t+1}
vt+1=γvt−η∗∇wtJ(wt)wt+1=wt−γvt+(1+γ)∗vt+1
证明:
两个公式是等价的;
解:
我们首先将原始公式进行一些改写,以方便我们进行理解,
v
t
+
1
=
γ
∗
v
t
−
η
Δ
θ
t
+
γ
∗
v
t
J
(
θ
t
+
γ
∗
v
t
)
θ
t
+
1
=
θ
t
+
v
t
+
1
v_{t+1} = \gamma*v_{t} - \eta\Delta_{ \theta_t+ \gamma*v_{t}}J\left ( \theta_t+ \gamma*v_{t}\right )\\ \theta_{t+1} = \theta_t + v_{t+1}
vt+1=γ∗vt−ηΔθt+γ∗vtJ(θt+γ∗vt)θt+1=θt+vt+1
这里导数函数
∇
f
\nabla f
∇f的自变量和求导变量都是因式
θ
t
+
μ
∗
v
t
\theta_t+ \mu*v_{t}
θt+μ∗vt
那么两个公式的等价性该如何证明呢,我们还是要回到问题的本身,
这里的需求是找到一种形如“
θ
j
=
θ
j
−
α
∂
∂
θ
j
J
(
θ
)
\theta_j = \theta_j - \alpha\frac{\partial}{\partial\theta_j}J\left(\theta\right)
θj=θj−α∂θj∂J(θ)”形式的公式,因为原始公式难以实现的原因就是导函数的自变量和求导变量都是一个子式,这是比较麻烦的;
为了能够简化计算,我们最好能将求导函数的自变量整合成一个变量,于是我们想到的方法就是换元;
于是我们令
w
t
=
θ
t
+
γ
∗
v
t
w_t = \theta_t+ \gamma*v_{t}
wt=θt+γ∗vt,
但是还有一个问题,换元之后怎么办呢,
θ
t
\theta_t
θt与
w
t
w_t
wt是什么关系呢,难道我们要使用
w
t
w_t
wt替代参数
θ
t
\theta_t
θt吗,但是我们本来要求的不是模型参数
θ
t
\theta_t
θt的局部最优解吗,怎么可以使用另一个参数代替呢,
别担心,答案是可以的;
在0时刻,有
w
0
=
θ
0
+
γ
∗
v
0
w_0 = \theta_0+ \gamma*v_0
w0=θ0+γ∗v0,而由于初始速度
v
0
=
0
v_0=0
v0=0,则有
w
0
=
θ
0
w_0 = \theta_0
w0=θ0,
而当
t
→
+
∞
t \rightarrow +\infty
t→+∞时,
v
t
v_t
vt无限趋近于0,这一点是可以想到的是因为当小球达到谷底时(也就是无限趋近与当前的局部最优解时),速度会下降到几乎为0
∴有
lim
t
→
+
∞
v
t
=
0
\lim_{t\rightarrow+\infty}v_t=0
limt→+∞vt=0,且
lim
t
→
+
∞
w
t
=
θ
t
\lim_{t\rightarrow+\infty}w_t=\theta_t
limt→+∞wt=θt
由此可知,我们可以使用
w
t
w_t
wt来代替
θ
t
\theta_t
θt进行迭代计算,
于是我们将
w
t
=
θ
t
+
γ
∗
v
t
w_t = \theta_t + \gamma*v_t
wt=θt+γ∗vt带入公式S,则有
v
t
+
1
=
γ
∗
v
t
−
η
∇
w
t
f
(
w
t
)
w
t
+
1
−
γ
∗
v
t
+
1
=
w
t
−
γ
∗
v
t
+
v
t
+
1
v_{t+1} = \gamma*v_{t} - \eta\nabla _{w_t}f\left(w_t\right )\\ w_{t + 1} - \gamma*v_{t + 1} = w_t - \gamma*v_t + v_{t+1}
vt+1=γ∗vt−η∇wtf(wt)wt+1−γ∗vt+1=wt−γ∗vt+vt+1
化简得
v
t
+
1
=
γ
∗
v
t
−
η
∇
w
t
f
(
w
t
)
w
t
+
1
=
w
t
+
(
1
+
γ
)
v
t
+
1
−
γ
∗
v
t
v_{t+1} = \gamma*v_{t} - \eta\nabla _{w_t}f\left(w_t\right )\\ w_{t + 1} = w_t + (1+\gamma)v_{t+1} - \gamma*v_t
vt+1=γ∗vt−η∇wtf(wt)wt+1=wt+(1+γ)vt+1−γ∗vt
与公式N相同,所以原命题得证;