权重衰减(weight decay)在贝叶斯推断(Bayesian inference)下的理解
摘要
对于有过拟合的模型,我们经常会用权重衰减(weight decay)这样一种正则化(regularization)的方法。直观上,权重衰减就是在原损失函数的基础上加入了一个对权重模(norm)的惩罚项。这个惩罚项的加入使我们可以权衡模型的灵活性(权重值的绝对值越大,模型越灵活)与稳定性(权重值的绝对值越小,模型越稳定)。那么权重衰减与贝叶斯推断是什么关系呢?本文就来简单介绍下贝叶斯视角下的权重衰减的理解。
权重衰减
假设我们模型的参数是 W W W ( W W W 是一个向量)。假设模型的损失函数是 J ( W ) \displaystyle J(W) J(W)。加入权重衰减之后,损失函数就变成了 J ( W ) + λ ∑ i w i 2 \displaystyle J(W) + \lambda \sum_i w_i^2 J(W)+λi∑wi2。为了不失一般性,我们将 J ( W ) \displaystyle J(W) J(W) 写成 ∑ j ( y j − f ( W , x j ) ) 2 \displaystyle \sum_j \big( y_j - f(W, x_j) \big)^2 j∑(yj−f(W,xj))2。 其中 y j y_j yj 是第 j j j 个训练数据对应的真实的 y y y 值。 λ \lambda λ 是一个超参数,可以通过交叉验证 (cross validation) 来确定。模型训练就是要找到最佳的参数 W W W,使得损失函数最小。
贝叶斯(Bayes inference) 视角下的权重衰减
在贝叶斯视角下如何去看权重衰减呢?
假设我们的训练数据集是 D D D,权重是 W W W。我们用条件概率 P ( W ∣ D ) \displaystyle \mathbb{P}(W \vert D) P(W∣D) 表示在观测到数据 D D D 的条件下,模型中参数为 W W W 的概率。
根据 Bayes 公式,我们可以将
P
(
W
∣
D
)
\displaystyle \mathbb{P}(W \vert D)
P(W∣D) 写成
P
(
W
∣
D
)
=
P
(
W
)
⋅
P
(
D
∣
W
)
P
(
D
)
\displaystyle \mathbb{P}(W \vert D) = \frac{\mathbb{P}(W) \cdot \mathbb{P}(D \vert W)}{\mathbb{P}(D)}
P(W∣D)=P(D)P(W)⋅P(D∣W)
从贝叶斯推断的角度去考虑模型的参数选择,我们就是希望找到模型的参数 W W W,使得 P ( W ∣ D ) \displaystyle \mathbb{P}(W \vert D) P(W∣D) 最大。
在上面的公式右边的项中, P ( D ) \displaystyle \mathbb{P}( D) P(D) 表示观测到数据 D D D 的概率,这个概率是通过对所有的权重 W W W 可能取到的值的积分而得,所以与 W W W 无关。
对于 P ( W ) \displaystyle \mathbb{P}(W) P(W),即权重的先验分布。我们假设 W W W 服从正态分布,可以写成 P ( W ) = 1 2 π σ w 2 e − w 2 2 σ w 2 \displaystyle \mathbb{P}(W) = \frac{1}{\sqrt{2 \pi \sigma_w^2}} e^{-\frac{w^2}{2 \sigma_w^2}} P(W)=2πσw21e−2σw2w2。
而对于 P ( D ∣ W ) \displaystyle \mathbb{P}(D \vert W) P(D∣W),它表示的是在给定模型的参数的时候,观察到训练数据集的条件概率。在这里我们做一个假设,即在给定输入数据 input 以及模型参数 W W W 的时候,准确的 y y y 的值的分布是一个正态分布。我们记准确的 y y y 的值是 t c t_c tc,那么我们的假设可以表示为 p ( t c ∣ y c ) = 1 2 π σ D 2 e − ( t c − y c ) 2 2 σ D 2 \displaystyle p(t_c \vert y_c) = \frac{1}{\sqrt{2 \pi \sigma_D^2}} e^{-\frac{(t_c - y_c)^2}{2 \sigma_D^2}} p(tc∣yc)=2πσD21e−2σD2(tc−yc)2。
似然函数(log likelihood)
我们的目的是要求使得 P ( W ∣ D ) \displaystyle \mathbb{P}(W \vert D) P(W∣D) 最大的权重参数 W W W。因为 log 函数是单调的,所以取 log 之后不改变结果。从而,我们有 arg max W P ( W ∣ D ) = arg max W ( log ( P ( W ∣ D ) ) ) \displaystyle \argmax_{W} \mathbb{P}(W \vert D) = \argmax_{W} \left( \log \left( \mathbb{P}(W \vert D) \right) \right) WargmaxP(W∣D)=Wargmax(log(P(W∣D)))。而
log ( P ( W ∣ D ) ) = log P ( W ) + log P ( D ∣ W ) − log P ( D ) \log\left( \mathbb{P}(W \vert D) \right) = \log \mathbb{P}(W) + \log \mathbb{P}(D \vert W) - \log \mathbb{P}(D) log(P(W∣D))=logP(W)+logP(D∣W)−logP(D)
因为
log
P
(
D
)
\displaystyle \log \mathbb{P}(D)
logP(D) 与
W
W
W 无关,于是我们有
arg max
W
(
log
(
P
(
W
∣
D
)
)
)
=
arg max
W
(
log
P
(
W
)
+
log
P
(
D
∣
W
)
)
\displaystyle \argmax_{W} \big( \log \left( \mathbb{P}(W \vert D) \right) \big) = \argmax_{W} \big( \log \mathbb{P}(W) + \log \mathbb{P}(D \vert W) \big)
Wargmax(log(P(W∣D)))=Wargmax(logP(W)+logP(D∣W))
根据之前的分析,我们可以把。 log P ( W ) \displaystyle \log P(W) logP(W) 写成 − w 2 2 σ w 2 − log ( 2 π ) − log ( σ w ) \displaystyle -\frac{w^2}{2 \sigma_w^2} - \log (\sqrt{2 \pi}) - \log(\sigma_w) −2σw2w2−log(2π)−log(σw)。
而 log ( P ( D ∣ W ) ) = − ( t c − y c ) 2 2 σ D 2 − log ( 2 π ) − log ( σ D ) \displaystyle \log(\mathbb{P}(D \vert W)) = -\frac{(t_c - y_c)^2}{2 \sigma_D^2} - \log (\sqrt{2 \pi}) - \log (\sigma_D) log(P(D∣W))=−2σD2(tc−yc)2−log(2π)−log(σD)
我们求使得 P ( W ∣ D ) \displaystyle \mathbb{P}(W \vert D) P(W∣D) 最大的权重参数 W W W,与求使得 − log ( P ( W ∣ D ) ) \displaystyle -\log \big( \mathbb{P}(W \vert D) \big) −log(P(W∣D)) 最小的 W W W 是一样的。省略掉常数项,同时省略掉 log ( σ w ) \displaystyle \log(\sigma_w) log(σw) 与 log ( σ D ) \displaystyle \log(\sigma_D) log(σD) 的项。我们就有
− log ( P ( W ∣ D ) ) = ∑ ( t c − y c ) 2 2 σ D 2 + ∑ w 2 2 σ w 2 -\log \big( \mathbb{P}(W \vert D) \big) = \sum \frac{(t_c - y_c)^2}{2 \sigma_D^2} + \sum \frac{w^2}{2 \sigma_w^2} −log(P(W∣D))=∑2σD2(tc−yc)2+∑2σw2w2
可以看出,这个表达式与我们把权重衰减当作惩罚项去优化的表达式是一样的。
于是我们发现,在一定的假设条件下,贝叶斯推断下去找 W W W,使得 − log ( P ( W ∣ D ) ) \displaystyle -\log \big( \mathbb{P}(W \vert D) \big) −log(P(W∣D)) 最小,与把权重衰减当作惩罚项,得到的表达式是一样的。