权重衰减(weight decay)在贝叶斯推断(Bayesian inference)下的理解

权重衰减(weight decay)在贝叶斯推断(Bayesian inference)下的理解

摘要

对于有过拟合的模型,我们经常会用权重衰减(weight decay)这样一种正则化(regularization)的方法。直观上,权重衰减就是在原损失函数的基础上加入了一个对权重模(norm)的惩罚项。这个惩罚项的加入使我们可以权衡模型的灵活性(权重值的绝对值越大,模型越灵活)与稳定性(权重值的绝对值越小,模型越稳定)。那么权重衰减与贝叶斯推断是什么关系呢?本文就来简单介绍下贝叶斯视角下的权重衰减的理解。

权重衰减

假设我们模型的参数是 W W W W W W 是一个向量)。假设模型的损失函数是 J ( W ) \displaystyle J(W) J(W)。加入权重衰减之后,损失函数就变成了 J ( W ) + λ ∑ i w i 2 \displaystyle J(W) + \lambda \sum_i w_i^2 J(W)+λiwi2。为了不失一般性,我们将 J ( W ) \displaystyle J(W) J(W) 写成 ∑ j ( y j − f ( W , x j ) ) 2 \displaystyle \sum_j \big( y_j - f(W, x_j) \big)^2 j(yjf(W,xj))2。 其中 y j y_j yj 是第 j j j 个训练数据对应的真实的 y y y 值。 λ \lambda λ 是一个超参数,可以通过交叉验证 (cross validation) 来确定。模型训练就是要找到最佳的参数 W W W,使得损失函数最小。

贝叶斯(Bayes inference) 视角下的权重衰减

在贝叶斯视角下如何去看权重衰减呢?

假设我们的训练数据集是 D D D,权重是 W W W。我们用条件概率 P ( W ∣ D ) \displaystyle \mathbb{P}(W \vert D) P(WD) 表示在观测到数据 D D D 的条件下,模型中参数为 W W W 的概率。

根据 Bayes 公式,我们可以将 P ( W ∣ D ) \displaystyle \mathbb{P}(W \vert D) P(WD) 写成
P ( W ∣ D ) = P ( W ) ⋅ P ( D ∣ W ) P ( D ) \displaystyle \mathbb{P}(W \vert D) = \frac{\mathbb{P}(W) \cdot \mathbb{P}(D \vert W)}{\mathbb{P}(D)} P(WD)=P(D)P(W)P(DW)

从贝叶斯推断的角度去考虑模型的参数选择,我们就是希望找到模型的参数 W W W,使得 P ( W ∣ D ) \displaystyle \mathbb{P}(W \vert D) P(WD) 最大。

在上面的公式右边的项中, P ( D ) \displaystyle \mathbb{P}( D) P(D) 表示观测到数据 D D D 的概率,这个概率是通过对所有的权重 W W W 可能取到的值的积分而得,所以与 W W W 无关。

对于 P ( W ) \displaystyle \mathbb{P}(W) P(W),即权重的先验分布。我们假设 W W W 服从正态分布,可以写成 P ( W ) = 1 2 π σ w 2 e − w 2 2 σ w 2 \displaystyle \mathbb{P}(W) = \frac{1}{\sqrt{2 \pi \sigma_w^2}} e^{-\frac{w^2}{2 \sigma_w^2}} P(W)=2πσw2 1e2σw2w2

而对于 P ( D ∣ W ) \displaystyle \mathbb{P}(D \vert W) P(DW),它表示的是在给定模型的参数的时候,观察到训练数据集的条件概率。在这里我们做一个假设,即在给定输入数据 input 以及模型参数 W W W 的时候,准确的 y y y 的值的分布是一个正态分布。我们记准确的 y y y 的值是 t c t_c tc,那么我们的假设可以表示为 p ( t c ∣ y c ) = 1 2 π σ D 2 e − ( t c − y c ) 2 2 σ D 2 \displaystyle p(t_c \vert y_c) = \frac{1}{\sqrt{2 \pi \sigma_D^2}} e^{-\frac{(t_c - y_c)^2}{2 \sigma_D^2}} p(tcyc)=2πσD2 1e2σD2(tcyc)2

似然函数(log likelihood)

我们的目的是要求使得 P ( W ∣ D ) \displaystyle \mathbb{P}(W \vert D) P(WD) 最大的权重参数 W W W。因为 log 函数是单调的,所以取 log 之后不改变结果。从而,我们有 arg max ⁡ W P ( W ∣ D ) = arg max ⁡ W ( log ⁡ ( P ( W ∣ D ) ) ) \displaystyle \argmax_{W} \mathbb{P}(W \vert D) = \argmax_{W} \left( \log \left( \mathbb{P}(W \vert D) \right) \right) WargmaxP(WD)=Wargmax(log(P(WD)))。而

log ⁡ ( P ( W ∣ D ) ) = log ⁡ P ( W ) + log ⁡ P ( D ∣ W ) − log ⁡ P ( D ) \log\left( \mathbb{P}(W \vert D) \right) = \log \mathbb{P}(W) + \log \mathbb{P}(D \vert W) - \log \mathbb{P}(D) log(P(WD))=logP(W)+logP(DW)logP(D)

因为 log ⁡ P ( D ) \displaystyle \log \mathbb{P}(D) logP(D) W W W 无关,于是我们有
arg max ⁡ W ( log ⁡ ( P ( W ∣ D ) ) ) = arg max ⁡ W ( log ⁡ P ( W ) + log ⁡ P ( D ∣ W ) ) \displaystyle \argmax_{W} \big( \log \left( \mathbb{P}(W \vert D) \right) \big) = \argmax_{W} \big( \log \mathbb{P}(W) + \log \mathbb{P}(D \vert W) \big) Wargmax(log(P(WD)))=Wargmax(logP(W)+logP(DW))

根据之前的分析,我们可以把。 log ⁡ P ( W ) \displaystyle \log P(W) logP(W) 写成 − w 2 2 σ w 2 − log ⁡ ( 2 π ) − log ⁡ ( σ w ) \displaystyle -\frac{w^2}{2 \sigma_w^2} - \log (\sqrt{2 \pi}) - \log(\sigma_w) 2σw2w2log(2π )log(σw)

log ⁡ ( P ( D ∣ W ) ) = − ( t c − y c ) 2 2 σ D 2 − log ⁡ ( 2 π ) − log ⁡ ( σ D ) \displaystyle \log(\mathbb{P}(D \vert W)) = -\frac{(t_c - y_c)^2}{2 \sigma_D^2} - \log (\sqrt{2 \pi}) - \log (\sigma_D) log(P(DW))=2σD2(tcyc)2log(2π )log(σD)

我们求使得 P ( W ∣ D ) \displaystyle \mathbb{P}(W \vert D) P(WD) 最大的权重参数 W W W,与求使得 − log ⁡ ( P ( W ∣ D ) ) \displaystyle -\log \big( \mathbb{P}(W \vert D) \big) log(P(WD)) 最小的 W W W 是一样的。省略掉常数项,同时省略掉 log ⁡ ( σ w ) \displaystyle \log(\sigma_w) log(σw) log ⁡ ( σ D ) \displaystyle \log(\sigma_D) log(σD) 的项。我们就有

− log ⁡ ( P ( W ∣ D ) ) = ∑ ( t c − y c ) 2 2 σ D 2 + ∑ w 2 2 σ w 2 -\log \big( \mathbb{P}(W \vert D) \big) = \sum \frac{(t_c - y_c)^2}{2 \sigma_D^2} + \sum \frac{w^2}{2 \sigma_w^2} log(P(WD))=2σD2(tcyc)2+2σw2w2

可以看出,这个表达式与我们把权重衰减当作惩罚项去优化的表达式是一样的。

于是我们发现,在一定的假设条件下,贝叶斯推断下去找 W W W,使得 − log ⁡ ( P ( W ∣ D ) ) \displaystyle -\log \big( \mathbb{P}(W \vert D) \big) log(P(WD)) 最小,与把权重衰减当作惩罚项,得到的表达式是一样的。

参考资料

[1] Hinton deep learning 课程相关视频

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值