11. Variational Inference
1. Background
$$
\begin{cases}
\text{Frequentist view: optimization problems}
\begin{cases}
\text{Regression}
\begin{cases}
\text{Model}\\
\text{Strategy}\\
\text{Algorithm}
\begin{cases}
\text{analytic solution}\\
\text{numerical solution}
\end{cases}
\end{cases}\\
\text{SVM}\\
\text{EM}\\
\text{etc.}
\end{cases}\\
\text{Bayesian view: integration problems}
\begin{cases}
\text{Bayesian inference (compute the posterior)}\\
P(\theta \mid x)=\frac{P(x \mid \theta)P(\theta)}{P(x)}\\
\text{Bayesian decision (prediction, which again requires the posterior)}\\
P(\tilde{x} \mid x)=\int_{\theta}P(\tilde{x},\theta \mid x)\, d\theta=\int_{\theta}P(\tilde{x} \mid \theta)P(\theta \mid x)\,d\theta=E_{\theta \mid x}[P(\tilde{x} \mid \theta)]
\end{cases}
\end{cases}
$$
$$
\text{Inference}
\begin{cases}
\text{exact inference (simple posterior)}\\
\text{approximate inference (very complex parameter/latent space)}
\begin{cases}
\text{deterministic approximation} \to \text{VI}\\
\text{stochastic approximation} \to \text{MCMC, MH, Gibbs}
\end{cases}
\end{cases}
$$
2. Derivation
Notation:

- $x$: observed data
- $z$: latent variables + parameters
- $(x,z)$: complete data
ELBO + KL decomposition:

$$\log P(x)= L(q)+KL(q \,\|\, p)$$
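This decomposition can be checked numerically on a toy discrete example (the joint probability table below is made up purely for illustration):

```python
import numpy as np

# Toy check of log P(x) = L(q) + KL(q || p(z|x)) for a discrete z.
# Joint p(x, z) for one fixed observation x, with z in {0, 1, 2}
# (arbitrary illustrative numbers; only the identity matters).
p_xz = np.array([0.1, 0.25, 0.05])        # p(x, z) for each z
p_x = p_xz.sum()                          # evidence p(x)
p_z_given_x = p_xz / p_x                  # true posterior p(z | x)

q = np.array([0.5, 0.3, 0.2])             # any variational distribution

L = np.sum(q * np.log(p_xz / q))          # ELBO L(q)
KL = np.sum(q * np.log(q / p_z_given_x))  # KL(q || p(z|x))

# The decomposition holds exactly, and KL >= 0 makes L a lower bound.
assert np.isclose(np.log(p_x), L + KL)
print(L, KL, np.log(p_x))
```

Since $KL \ge 0$, the identity also shows directly why $L(q)$ is a lower bound on the evidence.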
$$\hat q(z)=\arg \max_{q(z)} L(q) \;\to\; \hat q(z) \approx p(z \mid x)$$
Mean-field theory (borrowed from statistical physics):

$$q(z)=\prod_{i=1}^M q_i(z_i)$$

During each update, all factors are held fixed except one dimension, $q_j(z_j)$, which is solved for.
$$L(q)=\int_z q(z) \log P(x,z)\,dz-\int_z q(z)\log q(z)\,dz$$
$$
\begin{aligned}
\int_z q(z) \log P(x,z)\,dz &=\int_z \prod_{i=1}^M q_i(z_i) \log P(x,z)\,dz\\
&=\int_{z_j} q_j(z_j)\left( \int \prod_{i \ne j} q_i(z_i)\, \log P(x,z)\, \prod_{i \ne j} dz_i \right) dz_j\\
&=\int_{z_j} q_j(z_j)\, E_{\prod_{i \ne j} q_i(z_i)}[\log P(x,z)]\, dz_j\\
\int_z q(z) \log q(z)\, dz &=\int_z \prod_{i=1}^{M}q_i(z_i) \log \prod_{i=1}^{M}q_i(z_i)\,dz\\
&=\int_z \prod_{i=1}^{M}q_i(z_i) \sum_{i=1}^{M} \log q_i(z_i)\,dz\\
&=\int_z \prod_{i=1}^{M}q_i(z_i) \left[ \log q_1(z_1)+\log q_2(z_2)+\cdots+\log q_M(z_M) \right] dz\\
&=\sum_{i=1}^{M} \int_{z_i}q_i(z_i) \log q_i(z_i)\,dz_i\\
&=\int_{z_j}q_j(z_j) \log q_j(z_j)\,dz_j + C\\
L(q)&=\int_{z_j} q_j(z_j)\log \frac{\hat p(x,z_j)}{q_j(z_j)}\,dz_j + C\\
&=-KL(q_j \,\|\, \hat p(x,z_j)) + C
\end{aligned}
$$

where $\log \hat p(x,z_j) = E_{\prod_{i\ne j} q_i(z_i)}[\log P(x,z)]$ (up to normalization); with the other factors fixed, $L(q)$ is therefore maximized when $q_j(z_j)=\hat p(x,z_j)$.
The EM algorithm handles maximum likelihood with latent variables, where the bound is made tight by setting $q = p(z \mid x)$. In general the posterior $p$ is too complex to obtain directly, so instead we minimize the KL divergence to get the best approximation $\hat q$: under the mean-field assumption, the mutually independent factors $\prod_{i=1}^{M}q_i(z_i)$ are used to approximate the posterior $p$.
3. Looking Back
VI (mean field) $\to$ Classical VI
Assumption:

$$q(z)=\prod_{i=1}^{M}q_i(z_i)$$
$$
\begin{aligned}
\log q_j(z_j) &=E_{\prod_{i \ne j} q_i(z_i)}[\log P_{\theta}(x^{(i)},z)]+C\\
&=\int_{z_1}\cdots\int_{z_{j-1}} \int_{z_{j+1}} \cdots \int_{z_M} q_1 \cdots q_{j-1}\, q_{j+1} \cdots q_M \,\log P_{\theta}(x^{(i)},z)\; dz_1\cdots dz_{j-1}\, dz_{j+1} \cdots dz_M
\end{aligned}
$$
Objective:
$$
\begin{aligned}
\hat q &= \arg \min_q KL(q \| p)=\arg \max_q L(q)\\
\log \hat q_1(z_1) &=\int_{z_2} \cdots \int_{z_M} q_2 \cdots q_M \,\log P_{\theta}(x^{(i)},z)\; dz_2 \cdots dz_M + C\\
\log \hat q_2(z_2) &=\int_{z_1} \int_{z_3} \cdots \int_{z_M} \hat q_1\, q_3 \cdots q_M \,\log P_{\theta}(x^{(i)},z)\; dz_1\, dz_3 \cdots dz_M + C\\
&\;\;\vdots\\
\log \hat q_M(z_M) &=\int_{z_1} \int_{z_2} \cdots \int_{z_{M-1}} \hat q_1\, \hat q_2 \cdots \hat q_{M-1} \,\log P_{\theta}(x^{(i)},z)\; dz_1\, dz_2 \cdots dz_{M-1} + C
\end{aligned}
$$

Each update uses the most recent estimates of the other factors.
This is analogous to coordinate ascent (cf. gradient ascent): update one factor at a time with the others fixed, and stop at convergence.
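As a sketch of such a coordinate-ascent scheme, here is mean-field VI on a standard textbook-style target, a correlated bivariate Gaussian; the specific numbers and the closed-form Gaussian updates are assumptions of this example, not from the notes:

```python
import numpy as np

# Mean-field coordinate ascent on a toy target: a correlated bivariate
# Gaussian "posterior" N(mu, inv(Lam)). Illustrative values only.
mu = np.array([1.0, -1.0])
Lam = np.array([[2.0, 0.9],          # precision matrix (positive definite)
                [0.9, 1.5]])

# Under q(z) = q1(z1) q2(z2), the optimal factors are Gaussians with
# fixed variances 1/Lam[j, j]; their means satisfy the coupled
# coordinate-ascent updates below (each uses the latest other mean).
m = np.zeros(2)
for _ in range(50):
    m[0] = mu[0] - (Lam[0, 1] / Lam[0, 0]) * (m[1] - mu[1])
    m[1] = mu[1] - (Lam[1, 0] / Lam[1, 1]) * (m[0] - mu[0])

print(m)   # converges to the true mean mu = [1, -1]
```

In this example the factor means converge to the true posterior mean, while the factor variances $1/\Lambda_{jj}$ underestimate the true marginal variances, a well-known consequence of minimizing $KL(q\|p)$ under a factorized family.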
Problems with Classical VI:
- The factorization assumption is too strong
- Still intractable (the expectations themselves require integrals)
4. SGVI
Stochastic Gradient Variational Inference: instead of solving for $q(z)$ pointwise, assume $q(z)$ belongs to some parametric family and solve for the parameter $\phi$ of that distribution.
ELBO:
$$
\begin{aligned}
L(\phi)&=E_{q_{\phi}(z)} \left[ \log \frac{P_{\theta}(x^{(i)},z)}{q_{\phi}(z)}\right]\\
\hat \phi &= \arg \max_{\phi} L(\phi)\\
\nabla_{\phi}L(\phi) &=\nabla_{\phi}E_{q_{\phi}(z)} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right]\\
&=\nabla_{\phi}\int q_{\phi}(z) \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz\\
&=\int \nabla_{\phi} q_{\phi}(z) \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz + \int q_{\phi}(z)\, \nabla_{\phi} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz\\
&=\int q_{\phi}(z)\, \nabla_{\phi} \log q_{\phi}(z) \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz-\int \nabla_{\phi} q_{\phi}(z)\,dz\\
&=E_{q_{\phi}(z)} \left[ \nabla_{\phi} \log q_{\phi}(z) \left( \log P_{\theta}(x^{(i)},z) - \log q_{\phi}(z) \right) \right]
\end{aligned}
$$

Here $\nabla_{\phi} q_{\phi}(z) = q_{\phi}(z)\nabla_{\phi}\log q_{\phi}(z)$ is used in both terms, and the last integral vanishes because $\int \nabla_{\phi} q_{\phi}(z)\,dz = \nabla_{\phi}\int q_{\phi}(z)\,dz = \nabla_{\phi} 1 = 0$.
Therefore Monte Carlo can be used: draw samples from $q_{\phi}(z)$ and, by the law of large numbers, approximate the expectation by a sample mean:

$$z^{(l)} \sim q_{\phi}(z),\quad l=1,2,\cdots,L$$
$$\nabla_{\phi}L(\phi) \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_{\phi} \log q_{\phi}(z^{(l)})\left(\log P_{\theta}(x^{(i)},z^{(l)})-\log q_{\phi}(z^{(l)})\right)$$
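A minimal sketch of this Monte Carlo score-function estimator, under toy assumptions not from the notes: $q_{\phi}(z)=\mathcal N(\phi,1)$, and $\log \mathcal N(z;0,1)$ standing in for the joint $\log P_{\theta}(x^{(i)},z)$:

```python
import numpy as np

# Score-function (REINFORCE-style) estimate of the ELBO gradient.
# Assumed toy setup: q_phi(z) = N(phi, 1); log P(x, z) := log N(z; 0, 1).
rng = np.random.default_rng(0)

def log_p(z):                  # stand-in for log P_theta(x, z)
    return -0.5 * z**2 - 0.5 * np.log(2 * np.pi)

def log_q(z, phi):             # log q_phi(z) with q_phi = N(phi, 1)
    return -0.5 * (z - phi)**2 - 0.5 * np.log(2 * np.pi)

def score(z, phi):             # grad_phi log q_phi(z) = z - phi
    return z - phi

phi, L = 2.0, 100_000
z = rng.normal(phi, 1.0, size=L)    # z^(l) ~ q_phi(z)
grad_hat = np.mean(score(z, phi) * (log_p(z) - log_q(z, phi)))

# For this toy model the ELBO is -phi^2/2 + const, so the true
# gradient is -phi = -2; even with 10^5 samples the estimate fluctuates.
print(grad_hat)
```

Even in this one-dimensional example the estimator needs many samples to get close to the true value, which illustrates the variance problem discussed next.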
Problems:
The trouble lies in the factor $\nabla_{\phi} \log q_{\phi}(z)$: when a sample lands where $q_{\phi}(z)$ is close to 0, the logarithm changes very rapidly (very sensitive, high variance), so many more samples are needed for a good approximation. Moreover, the expectation only approximates the gradient through $q_{\phi}(z)$, while the actual target of optimization is $\hat \phi$, so the resulting error can be very large.
Reparameterization Trick
$$\nabla_{\phi}L(\phi) =\nabla_{\phi}E_{q_{\phi}(z)} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right]$$
The expectation is taken with respect to $q_{\phi}(z)$; both the distribution $q_{\phi}(z)$ and the integrand depend on $\phi$, which makes the problem very hard. To simplify, make the sampling distribution independent of $\phi$: replace $q_{\phi}(z)$ with a fixed distribution $p(\varepsilon)$, so that the gradient can be applied directly to the integrand rather than to the expectation itself.
Given $z \sim p_{\phi}(z \mid x)$, the reparameterization trick is introduced to decouple $z$ from $\phi$.
Assume $z=g_{\phi}(\varepsilon, x^{(i)})$ with $\varepsilon \sim p(\varepsilon)$, so that $z$ is a deterministic mapping of $\varepsilon$. Since each density integrates to 1, the change of variables gives:

$$\left| p_{\phi}(z \mid x^{(i)})\,dz \right| = \left| p(\varepsilon)\,d\varepsilon \right|$$
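The change-of-variables identity can be sanity-checked numerically for the common Gaussian case $z=\mu+\sigma\varepsilon$ (an assumed form, used here only for illustration):

```python
import numpy as np

# Check |p_phi(z) dz| = |p(eps) d eps| for z = mu + sigma * eps,
# eps ~ N(0, 1): then p_phi(z) |dz/deps| = p(eps), i.e.
# N(z; mu, sigma) * sigma = N(eps; 0, 1). mu, sigma are made-up values.
def normal_pdf(x, mean=0.0, std=1.0):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

mu, sigma = 1.5, 0.7
eps = np.linspace(-3, 3, 7)
z = mu + sigma * eps

lhs = normal_pdf(z, mu, sigma) * sigma   # p_phi(z) |dz/deps|
rhs = normal_pdf(eps)                    # p(eps)
assert np.allclose(lhs, rhs)
```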
$$
\begin{aligned}
\nabla_{\phi}L(\phi) &=\nabla_{\phi}\int \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] q_{\phi}(z)\, dz\\
&=\nabla_{\phi}\int \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] p(\varepsilon)\, d\varepsilon\\
&=\nabla_{\phi} E_{p(\varepsilon)} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] \\
&=E_{p(\varepsilon)} \left[ \nabla_{\phi} \left( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right) \right]\\
&=E_{p(\varepsilon)} \left[ \nabla_{z} \left( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right) \cdot \nabla_{\phi}z \right]\\
&=E_{p(\varepsilon)} \left[ \nabla_{z} \left( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right) \cdot \nabla_{\phi} g_{\phi}(\varepsilon, x^{(i)}) \right]
\end{aligned}
$$
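A sketch of this reparameterized gradient on a toy model, under the same kind of assumptions as before (not from the notes): $q_{\phi}(z)=\mathcal N(\phi,1)$, joint $\log P_{\theta}(x,z) := \log \mathcal N(z;0,1)$, and $z = g_{\phi}(\varepsilon) = \phi + \varepsilon$:

```python
import numpy as np

# Reparameterization-trick gradient estimate. Assumed toy setup:
# q_phi(z) = N(phi, 1), log P(x, z) := log N(z; 0, 1),
# z = g_phi(eps) = phi + eps with eps ~ N(0, 1), so grad_phi z = 1.
rng = np.random.default_rng(0)

phi, L = 2.0, 100
eps = rng.normal(0.0, 1.0, size=L)    # eps^(l) ~ p(eps)
z = phi + eps                         # z = g_phi(eps)

grad_z_log_p = -z                     # grad_z log N(z; 0, 1)
grad_z_log_q = -(z - phi)             # grad_z log N(z; phi, 1)
grad_hat = np.mean((grad_z_log_p - grad_z_log_q) * 1.0)  # * grad_phi z

# Each sample contributes (-z) - (-(z - phi)) = -phi, so this estimator
# is exact here with zero variance, unlike the score-function form.
print(grad_hat)   # -2.0
```

Only 100 samples are used, yet the estimate is exact in this toy case; in general the reparameterized estimator has much lower variance than the score-function estimator because the gradient acts on the integrand directly.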