致谢
感谢知乎回答“loss的尺度的影响的本质上取决于优化器”给予我的启发!
1 问题描述
最近我在调参时,想到一个问题:“损失函数loss乘以正数因子a是否等价于学习率lr乘以a呢?”
2 解答
对于梯度优化而言,损失函数loss乘以正数因子a与学习率lr乘以a,与优化器算法optimizer有关。
Optimizer | Equivalency (ls & lr) | Experiment |
---|---|---|
SGD | unknown | |
Adam | unknown | |
AdamW | ✓ | Torch_adamw_ls.ipynb |
3 证明
3.1 朴素SGD:等价
3.2 Adam:不等价,ls失效
这里选择的Adam算法形式是比较简单的算法描述,来自于Adam论文《A Method for Stochastic Optimization》,其公式为
直观解释:根据以上算法流程,当loss乘以尺度
s
s
s时,loss梯度
g
t
g_t
gt增大
s
s
s倍,那么
g
t
2
{g_t}^2
gt2则会扩大
s
2
s^2
s2倍;由于
m
t
m_t
mt是
g
t
g_t
gt的累加,
v
t
v_t
vt是
g
t
2
{g_t}^2
gt2的累加,那么
m
t
m_t
mt会扩大
s
s
s倍,
v
t
v_t
vt会扩大
s
2
s^2
s2倍。由于
m
^
t
\hat{m}_t
m^t与
m
t
m_t
mt线性相关,
v
^
t
\hat{v}_t
v^t与
v
t
v_t
vt线性相关,那么最后一步的梯度更新相当于
θ
t
=
θ
t
−
1
−
α
⋅
s
∗
m
^
t
s
2
∗
v
^
t
+
ϵ
=
θ
t
−
1
−
α
⋅
m
^
t
v
^
t
+
ϵ
/
s
\theta_t=\theta_{t-1}- \alpha \cdot \frac{s\ast\hat{m}_t}{ \sqrt{s^2\ast\hat{v}_t} + \epsilon}=\theta_{t-1}- \alpha \cdot \frac{\hat{m}_t}{ \sqrt{\hat{v}_t} + \epsilon/s}
θt=θt−1−α⋅s2∗v^t+ϵs∗m^t=θt−1−α⋅v^t+ϵ/sm^t
由于
ϵ
/
s
\epsilon/s
ϵ/s为极小值,可以忽略不计,则可以看到ls
没有对梯度更新产生作用,即ls失效;
以上过程可以使用代码进行验证。[]
则原命题等价于:已知函数
h
t
=
α
⋅
m
^
t
/
(
v
^
t
+
ϵ
)
=
h
(
α
,
f
t
)
h_t=\alpha \cdot \hat{m}_t/\left( \sqrt{\hat{v}_t} + \epsilon \right )=h(\alpha,f_t)
ht=α⋅m^t/(v^t+ϵ)=h(α,ft)
对任意正数因子
m
m
m,有函数
h
1
=
h
(
m
α
,
f
t
)
h_1=h(m\alpha,f_t)
h1=h(mα,ft)
以及函数
h
2
=
h
(
α
,
m
f
t
)
h_2=h(\alpha,mf_t)
h2=h(α,mft)
可知
h
2
=
h
(
α
,
m
f
t
)
=
α
⋅
m
^
t
′
/
(
v
^
t
′
+
ϵ
)
=
行
3
等号右
\begin{aligned} h_2 &= h(\alpha,mf_t) \\ &= \alpha \cdot {\hat{m}_t}'/\left( \sqrt{{\hat{v}_t}'} + \epsilon \right ) \\ &= 行3等号右 \\ \end{aligned}
h2=h(α,mft)=α⋅m^t′/(v^t′+ϵ)=行3等号右
则有
h
1
≡
h
2
h_1\equiv h_2
h1≡h2。
3.2 AdamW:不等价, ls失效
在分析之前,我们首先回顾一下AdamW的算法过程,这里我们参考PyTorch-doc中给出的算法描述:
从文档上可以看出,基于AdamW的梯度更新与loss函数的尺度是无关的;