变分推断
我们已经知道概率模型可以分为,频率派的优化问题和贝叶斯派的积分问题。从贝叶斯角度来看推断,对于
x
^
\hat{x}
x^ 这样的新样本,需要得到:
p
(
x
^
∣
X
)
=
∫
θ
p
(
x
^
,
θ
∣
X
)
d
θ
=
∫
θ
p
(
θ
∣
X
)
p
(
x
^
∣
θ
,
X
)
d
θ
p(\hat{x}|X)=\int_\theta p(\hat{x},\theta|X)d\theta=\int_\theta p(\theta|X)p(\hat{x}|\theta,X)d\theta
p(x^∣X)=∫θp(x^,θ∣X)dθ=∫θp(θ∣X)p(x^∣θ,X)dθ
如果新样本和数据集独立,那么推断就是概率分布依参数后验分布的期望。
我们看到,推断问题的中心是参数后验分布的求解,推断分为:
- 精确推断
- 近似推断-参数空间无法精确求解
- 确定性近似-如变分推断
- 随机近似-如 MCMC,MH,Gibbs
基于平均场假设的变分推断
我们记
Z
Z
Z 为隐变量和参数的集合,
Z
i
Z_i
Zi 为第
i
i
i 维的参数,于是,回顾一下 EM 中的推导:
log
p
(
X
)
=
log
p
(
X
,
Z
)
−
log
p
(
Z
∣
X
)
=
log
p
(
X
,
Z
)
q
(
Z
)
−
log
p
(
Z
∣
X
)
q
(
Z
)
\log p(X)=\log p(X,Z)-\log p(Z|X)=\log\frac{p(X,Z)}{q(Z)}-\log\frac{p(Z|X)}{q(Z)}
logp(X)=logp(X,Z)−logp(Z∣X)=logq(Z)p(X,Z)−logq(Z)p(Z∣X)
左右两边分别积分:
L
e
f
t
:
∫
Z
q
(
Z
)
log
p
(
X
)
d
Z
=
log
p
(
X
)
R
i
g
h
t
:
∫
Z
[
log
p
(
X
,
Z
)
q
(
Z
)
−
log
p
(
Z
∣
X
)
q
(
Z
)
]
q
(
Z
)
d
Z
=
E
L
B
O
+
K
L
(
q
,
p
)
Left:\int_Zq(Z)\log p(X)dZ=\log p(X)\\ Right:\int_Z[\log \frac{p(X,Z)}{q(Z)}-\log \frac{p(Z|X)}{q(Z)}]q(Z)dZ=ELBO+KL(q,p)
Left:∫Zq(Z)logp(X)dZ=logp(X)Right:∫Z[logq(Z)p(X,Z)−logq(Z)p(Z∣X)]q(Z)dZ=ELBO+KL(q,p)
第二个式子可以写为变分和 KL 散度的和:
L
(
q
)
+
K
L
(
q
,
p
)
L(q)+KL(q,p)
L(q)+KL(q,p)
由于这个式子是常数,于是寻找
q
≃
p
q\simeq p
q≃p 就相当于对
L
(
q
)
L(q)
L(q) 最大值。
q
^
(
Z
)
=
a
r
g
m
a
x
q
(
Z
)
L
(
q
)
\hat{q}(Z)=\mathop{argmax}_{q(Z)}L(q)
q^(Z)=argmaxq(Z)L(q)
假设
q
(
Z
)
q(Z)
q(Z) 可以划分为
M
M
M 个组(平均场近似):
q
(
Z
)
=
∏
i
=
1
M
q
i
(
Z
i
)
q(Z)=\prod\limits_{i=1}^Mq_i(Z_i)
q(Z)=i=1∏Mqi(Zi)
因此,在
L
(
q
)
=
∫
Z
q
(
Z
)
log
p
(
X
,
Z
)
d
Z
−
∫
Z
q
(
Z
)
log
q
(
Z
)
L(q)=\int_Zq(Z)\log p(X,Z)dZ-\int_Zq(Z)\log{q(Z)}
L(q)=∫Zq(Z)logp(X,Z)dZ−∫Zq(Z)logq(Z) 中,看
p
(
Z
j
)
p(Z_j)
p(Zj) ,第一项:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲\int_Zq(Z)\log …
第二项:
∫
Z
q
(
Z
)
log
q
(
Z
)
d
Z
=
∫
Z
∏
i
=
1
M
q
i
(
Z
i
)
∑
i
=
1
M
log
q
i
(
Z
i
)
d
Z
\int_Zq(Z)\log q(Z)dZ=\int_Z\prod\limits_{i=1}^Mq_i(Z_i)\sum\limits_{i=1}^M\log q_i(Z_i)dZ
∫Zq(Z)logq(Z)dZ=∫Zi=1∏Mqi(Zi)i=1∑Mlogqi(Zi)dZ
展开求和项第一项为:
∫
Z
∏
i
=
1
M
q
i
(
Z
i
)
log
q
1
(
Z
1
)
d
Z
=
∫
Z
1
q
1
(
Z
1
)
log
q
1
(
Z
1
)
d
Z
1
\int_Z\prod\limits_{i=1}^Mq_i(Z_i)\log q_1(Z_1)dZ=\int_{Z_1}q_1(Z_1)\log q_1(Z_1)dZ_1
∫Zi=1∏Mqi(Zi)logq1(Z1)dZ=∫Z1q1(Z1)logq1(Z1)dZ1
所以:
∫
Z
q
(
Z
)
log
q
(
Z
)
d
Z
=
∑
i
=
1
M
∫
Z
i
q
i
(
Z
i
)
log
q
i
(
Z
i
)
d
Z
i
=
∫
Z
j
q
j
(
Z
j
)
log
q
j
(
Z
j
)
d
Z
j
+
C
o
n
s
t
\int_Zq(Z)\log q(Z)dZ=\sum\limits_{i=1}^M\int_{Z_i}q_i(Z_i)\log q_i(Z_i)dZ_i=\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+Const
∫Zq(Z)logq(Z)dZ=i=1∑M∫Ziqi(Zi)logqi(Zi)dZi=∫Zjqj(Zj)logqj(Zj)dZj+Const
两项相减,令
E
∏
i
≠
j
q
i
(
Z
i
)
[
log
p
(
X
,
Z
)
]
=
log
p
^
(
X
,
Z
j
)
\mathbb{E}_{\prod\limits_{i\ne j}q_i(Z_i)}[\log p(X,Z)]=\log \hat{p}(X,Z_j)
Ei=j∏qi(Zi)[logp(X,Z)]=logp^(X,Zj) 可以得到:
−
∫
Z
j
q
j
(
Z
j
)
log
q
j
(
Z
j
)
p
^
(
X
,
Z
j
)
d
Z
j
≤
0
-\int_{Z_j}q_j(Z_j)\log\frac{q_j(Z_j)}{\hat{p}(X,Z_j)}dZ_j\le 0
−∫Zjqj(Zj)logp^(X,Zj)qj(Zj)dZj≤0
于是最大的
q
j
(
Z
j
)
=
p
^
(
X
,
Z
j
)
q_j(Z_j)=\hat{p}(X,Z_j)
qj(Zj)=p^(X,Zj) 才能得到最大值。我们看到,对每一个
q
j
q_j
qj,都是固定其余的
q
i
q_i
qi,求这个值,于是可以使用坐标上升的方法进行迭代求解,上面的推导针对单个样本,但是对数据集也是适用的。
基于平均场假设的变分推断存在一些问题:
- 假设太强, Z Z Z 非常复杂的情况下,假设不适用
- 期望中的积分,可能无法计算
SGVI
从 Z Z Z 到 X X X 的过程叫做生成过程或译码,反过来的额过程叫推断过程或编码过程,基于平均场的变分推断可以导出坐标上升的算法,但是这个假设在一些情况下假设太强,同时积分也不一定能算。我们知道,优化方法除了坐标上升,还有梯度上升的方式,我们希望通过梯度上升来得到变分推断的另一种算法。
我们的目标函数:
q
^
(
Z
)
=
a
r
g
m
a
x
q
(
Z
)
L
(
q
)
\hat{q}(Z)=\mathop{argmax}_{q(Z)}L(q)
q^(Z)=argmaxq(Z)L(q)
假定
q
(
Z
)
=
q
ϕ
(
Z
)
q(Z)=q_\phi(Z)
q(Z)=qϕ(Z),是和
ϕ
\phi
ϕ 这个参数相连的概率分布。于是
a
r
g
m
a
x
q
(
Z
)
L
(
q
)
=
a
r
g
m
a
x
ϕ
L
(
ϕ
)
\mathop{argmax}_{q(Z)}L(q)=\mathop{argmax}_{\phi}L(\phi)
argmaxq(Z)L(q)=argmaxϕL(ϕ),其中
L
(
ϕ
)
=
E
q
ϕ
[
log
p
θ
(
x
i
,
z
)
−
log
q
ϕ
(
z
)
]
L(\phi)=\mathbb{E}_{q_\phi}[\log p_\theta(x^i,z)-\log q_\phi(z)]
L(ϕ)=Eqϕ[logpθ(xi,z)−logqϕ(z)],这里
x
i
x^i
xi 表示第
i
i
i 个样本。
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲\nabla_\phi L(\…
这个期望可以通过蒙特卡洛采样来近似,从而得到梯度,然后利用梯度上升的方法来得到参数:
z
l
∼
q
ϕ
(
z
)
E
q
ϕ
[
(
∇
ϕ
log
q
ϕ
)
(
log
p
θ
(
x
i
,
z
)
−
log
q
ϕ
(
z
)
)
]
∼
1
L
∑
l
=
1
L
(
∇
ϕ
log
q
ϕ
)
(
log
p
θ
(
x
i
,
z
)
−
log
q
ϕ
(
z
)
)
z^l\sim q_\phi(z)\\ \mathbb{E}_{q_\phi}[(\nabla_\phi\log q_\phi)(\log p_\theta(x^i,z)-\log q_\phi(z))]\sim \frac{1}{L}\sum\limits_{l=1}^L(\nabla_\phi\log q_\phi)(\log p_\theta(x^i,z)-\log q_\phi(z))
zl∼qϕ(z)Eqϕ[(∇ϕlogqϕ)(logpθ(xi,z)−logqϕ(z))]∼L1l=1∑L(∇ϕlogqϕ)(logpθ(xi,z)−logqϕ(z))
但是由于求和符号中存在一个对数项,于是直接采样的方差很大,需要采样的样本非常多。为了解决方差太大的问题,我们采用 Reparameterization 的技巧。
考虑:
∇
ϕ
L
(
ϕ
)
=
∇
ϕ
E
q
ϕ
[
log
p
θ
(
x
i
,
z
)
−
log
q
ϕ
(
z
)
]
\nabla_\phi L(\phi)=\nabla_\phi\mathbb{E}_{q_\phi}[\log p_\theta(x^i,z)-\log q_\phi(z)]
∇ϕL(ϕ)=∇ϕEqϕ[logpθ(xi,z)−logqϕ(z)]
我们取:
z
=
g
ϕ
(
ε
,
x
i
)
,
ε
∼
p
(
ε
)
z=g_\phi(\varepsilon,x^i),\varepsilon\sim p(\varepsilon)
z=gϕ(ε,xi),ε∼p(ε),于是对后验:
z
∼
q
ϕ
(
z
∣
x
i
)
z\sim q_\phi(z|x^i)
z∼qϕ(z∣xi),有
∣
q
ϕ
(
z
∣
x
i
)
d
z
∣
=
∣
p
(
ε
)
d
ε
∣
|q_\phi(z|x^i)dz|=|p(\varepsilon)d\varepsilon|
∣qϕ(z∣xi)dz∣=∣p(ε)dε∣。代入上面的梯度中:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \nabla_\phi L(…
对这个式子进行蒙特卡洛采样,然后计算期望,得到梯度。