Rezende D., Mohamed S. Variational Inference with Normalizing Flows. ICML, 2015.
Overview
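In a VAE the prior distribution matters, but so does the posterior. We usually assume that $q_{\phi}(z|x)$ is Gaussian, which severely limits how closely the approximate posterior can match the true posterior.
This assumption is simply too strong.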
The normalizing flows method proposed in this paper increases the approximation power of $q_{\phi}$.
Main content
First, suppose we have obtained $q_0(z_0|x)$ (with $z_0$ drawn via reparameterized sampling). We then apply an invertible function $f$ to get $z_1 = f(z_0)$, and the distribution of $z_1$ satisfies:
$$
q(z_1) = q(z_0) \left|\det \nabla_z f^{-1}\right| = q(z_0) \left|\det \nabla f\right|^{-1}.
$$
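As a quick sanity check of the change-of-variables formula above, the following sketch uses a one-dimensional example that is not from the paper: $z_0 \sim \mathcal{N}(0, 1)$ and $f(z_0) = e^{z_0}$, both chosen purely for illustration. It compares the probability mass of an interval estimated from transformed samples with the mass obtained by integrating $q(z_0)\,|f'(z_0)|^{-1}$ over the same interval.

```python
import numpy as np

def standard_normal_pdf(z):
    return np.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)

# Illustrative invertible map (an assumption, not from the paper): f(z0) = exp(z0).
def f(z0):
    return np.exp(z0)

def f_inverse(z1):
    return np.log(z1)

def f_prime(z0):
    return np.exp(z0)

def pushforward_pdf(z1):
    """q(z1) = q(z0) * |f'(z0)|^{-1}, evaluated at z0 = f^{-1}(z1)."""
    z0 = f_inverse(z1)
    return standard_normal_pdf(z0) / np.abs(f_prime(z0))

# Mass of the interval (lo, hi) estimated from transformed samples.
rng = np.random.default_rng(0)
samples = f(rng.standard_normal(200_000))
lo, hi = 0.5, 2.0
empirical_mass = float(np.mean((samples > lo) & (samples < hi)))

# The same mass from the change-of-variables density (trapezoidal rule).
grid = np.linspace(lo, hi, 2001)
density = pushforward_pdf(grid)
integrated_mass = float(np.sum(0.5 * (density[1:] + density[:-1]) * np.diff(grid)))

print(empirical_mass, integrated_mass)
```

The two numbers should agree up to Monte Carlo noise, which is what the formula predicts.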
Applying this repeatedly gives:
$$
\begin{aligned}
z_K &= f_K \circ \cdots \circ f_2 \circ f_1(z_0), \\
\ln q_K(z_K) &= \ln q_0(z_0) - \sum_{k=1}^K \ln \left|\det \nabla_{z_{k-1}} f_k\right|.
\end{aligned}
$$
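The bookkeeping in the composed formula can be sketched as a loop that pushes a sample through each map and subtracts the log-determinant at every step. The element-wise affine maps used here are an assumption chosen only because the resulting density is a Gaussian with a known closed form, which makes the result checkable. The names `make_affine_flow` and `flow_log_density` are hypothetical.

```python
import numpy as np

def standard_normal_logpdf(z):
    d = z.shape[-1]
    return -0.5 * np.sum(z ** 2, axis=-1) - 0.5 * d * np.log(2.0 * np.pi)

def make_affine_flow(scale, shift):
    """Element-wise map z -> scale * z + shift; returns (new z, ln|det Jacobian|)."""
    def flow(z):
        log_abs_det = np.sum(np.log(np.abs(scale)))
        return scale * z + shift, log_abs_det
    return flow

def flow_log_density(z0, flows):
    """ln q_K(z_K) = ln q_0(z_0) - sum_k ln|det grad f_k|, with q_0 = N(0, I)."""
    z = z0
    log_q = standard_normal_logpdf(z0)
    for flow in flows:
        z, log_abs_det = flow(z)
        log_q = log_q - log_abs_det
    return z, log_q

rng = np.random.default_rng(0)
d, K = 3, 4
scales = [rng.uniform(0.5, 2.0, size=d) for _ in range(K)]
shifts = [rng.normal(size=d) for _ in range(K)]
flows = [make_affine_flow(s, t) for s, t in zip(scales, shifts)]

z0 = rng.standard_normal((5, d))
zK, log_qK = flow_log_density(z0, flows)

# Closed form: the composition is itself affine, so z_K is Gaussian
# with mean total_shift and per-dimension standard deviation total_scale.
total_scale = np.prod(scales, axis=0)
total_shift = np.zeros(d)
for s, t in zip(scales, shifts):
    total_shift = s * total_shift + t
reference = (standard_normal_logpdf((zK - total_shift) / total_scale)
             - np.sum(np.log(total_scale)))

print(np.max(np.abs(log_qK - reference)))
```

Affine maps keep the density Gaussian, so they add no expressive power; they serve here only to verify the accumulation of log-determinants.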
In other words, as long as the Jacobian determinants can be computed, the approximate posterior becomes far more expressive.
The negative ELBO then becomes:
$$
\begin{aligned}
\mathcal{F}(x) &= \mathbb{E}_{q_{\phi}(z|x)}[\ln q_{\phi}(z|x) - \ln p_{\theta}(x,z)] \\
&= \mathbb{E}_{q_0(z_0)}[\ln q_K(z_K) - \ln p_{\theta}(x,z_K)] \\
&= \mathbb{E}_{q_0(z_0)}[\ln q_0(z_0)] - \mathbb{E}_{q_0(z_0)}\Big[\sum_{k=1}^K \ln \left|\det \nabla_{z_{k-1}} f_k\right|\Big] \\
&\quad - \mathbb{E}_{q_0(z_0)}[\ln p_{\theta}(x,z_K)].
\end{aligned}
$$
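Note: the last term does not involve the density $q_K$, so it can be approximated directly by sampling: draw $z_0 \sim q_0$ and push it through the flow to obtain $z_K$.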
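A minimal Monte Carlo sketch of $\mathcal{F}(x) = \mathbb{E}_{q_0(z_0)}[\ln q_K(z_K) - \ln p_{\theta}(x, z_K)]$ follows. Everything concrete in it is an assumption made for illustration: a one-dimensional toy model $z \sim \mathcal{N}(0,1)$, $x \mid z \sim \mathcal{N}(z, 1)$, a base distribution $q_0 = \mathcal{N}(0,1)$, and affine flows given as `(scale, shift)` pairs. For this toy model the evidence is $p(x) = \mathcal{N}(x; 0, 2)$ and the exact posterior is $\mathcal{N}(x/2, 1/2)$, so the bound $\mathcal{F}(x) \ge -\ln p(x)$ can be checked numerically.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)

def normal_logpdf(value, mean, var):
    return -0.5 * (LOG_2PI + np.log(var) + (value - mean) ** 2 / var)

def log_joint(x, z):
    # Toy model (assumption): z ~ N(0, 1), x | z ~ N(z, 1).
    return normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)

def free_energy(x, flows, num_samples, rng):
    """Monte Carlo estimate of E_{q0}[ln q_K(z_K) - ln p(x, z_K)]."""
    z = rng.standard_normal(num_samples)       # z_0 ~ q_0 = N(0, 1)
    log_q0 = normal_logpdf(z, 0.0, 1.0)
    sum_log_det = np.zeros(num_samples)
    for scale, shift in flows:                 # affine flow: z -> scale * z + shift
        sum_log_det += np.log(np.abs(scale))
        z = scale * z + shift
    # ln q_K(z_K) = ln q_0(z_0) - sum of log-determinants
    return float(np.mean(log_q0 - sum_log_det - log_joint(x, z)))

rng = np.random.default_rng(0)
x = 1.5
neg_log_evidence = float(-normal_logpdf(x, 0.0, 2.0))

# No flow: q_K = q_0 = N(0, 1), which is not the posterior.
f_no_flow = free_energy(x, [], 20_000, rng)
# One affine flow mapping N(0, 1) onto the exact posterior N(x / 2, 1 / 2).
f_exact_flow = free_energy(x, [(np.sqrt(0.5), x / 2.0)], 20_000, rng)

print(neg_log_evidence, f_no_flow, f_exact_flow)
```

When the flow maps $q_0$ exactly onto the posterior, the estimate matches $-\ln p(x)$; without a flow the gap is the KL divergence between $q_0$ and the posterior. In a real VAE the flow parameters would be learned by minimizing this quantity rather than set by hand.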
Some suitable invertible transformations
$$
f(z) = z + u h(w^T z + b),
$$
where $h$ is a nonlinear activation function, and $u, w \in \mathbb{R}^d$ and $b \in \mathbb{R}$ are the parameters. Then
$$
\begin{aligned}
\psi(z) &= h'(w^T z + b)\, w, \\
\left|\det \nabla_z f\right| &= \left|1 + u^T \psi(z)\right|.
\end{aligned}
$$
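The determinant above follows from the matrix determinant lemma, $\det(I + u\,\psi^T) = 1 + u^T\psi$, since $\nabla_z f = I + u\,\psi(z)^T$. The sketch below checks it numerically for $h = \tanh$ (one possible choice of activation) by comparing the closed form with the determinant of a finite-difference Jacobian. The helper `numerical_jacobian` and all parameter values are made up for the check.

```python
import numpy as np

def planar_flow(z, u, w, b):
    """f(z) = z + u * tanh(w^T z + b) for a single vector z of shape (d,)."""
    return z + u * np.tanh(w @ z + b)

def planar_abs_det(z, u, w, b):
    """|det grad_z f| = |1 + u^T psi(z)| with psi(z) = h'(w^T z + b) w, h = tanh."""
    psi = (1.0 - np.tanh(w @ z + b) ** 2) * w
    return np.abs(1.0 + u @ psi)

def numerical_jacobian(func, z, eps=1e-6):
    """Central finite-difference Jacobian; costs O(d) evaluations of func."""
    d = z.shape[0]
    jac = np.zeros((d, d))
    for j in range(d):
        step = np.zeros(d)
        step[j] = eps
        jac[:, j] = (func(z + step) - func(z - step)) / (2.0 * eps)
    return jac

rng = np.random.default_rng(0)
d = 4
z, u, w = rng.normal(size=d), rng.normal(size=d), rng.normal(size=d)
b = 0.3

closed_form = planar_abs_det(z, u, w, b)
jacobian = numerical_jacobian(lambda v: planar_flow(v, u, w, b), z)
brute_force = np.abs(np.linalg.det(jacobian))

print(closed_form, brute_force)
```

The closed form needs only a couple of inner products, $O(d)$ work, whereas a generic determinant costs $O(d^3)$. The determinant identity holds for any $u$ and $w$, but the map is invertible only under a constraint on the parameters; for $h = \tanh$ the paper gives $w^T u \ge -1$ as a sufficient condition, which this sketch does not enforce.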
$$
\begin{aligned}
f(z) &= z + \beta \cdot h(\alpha, \gamma)(z - z_0), \\
\gamma &= |z - z_0|, \quad h(\alpha, \gamma) = 1/(\alpha + \gamma).
\end{aligned}
$$
In this case
$$
\left|\det \nabla_z f\right| = [1 + \beta h(\alpha, \gamma)]^{d-1}[1 + \beta h(\alpha, \gamma) + \beta h'(\alpha, \gamma)\gamma].
$$
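where $d$ is the dimension of $z$. Here $z_0$ denotes a fixed reference point (a parameter of the transformation) rather than the initial sample, and $h'$ is the derivative of $h$ with respect to $\gamma$.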
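The same kind of numerical check works for this second transformation. The sketch assumes $h'(\alpha, \gamma) = -1/(\alpha + \gamma)^2$, which is the derivative of $h$ with respect to $\gamma$, and renames the reference point to `z_ref` to avoid confusion with the initial sample. The parameter values and the helper `numerical_jacobian` are illustrative choices.

```python
import numpy as np

def radial_flow(z, z_ref, alpha, beta):
    """f(z) = z + beta * h(alpha, r) * (z - z_ref), with r = |z - z_ref|."""
    r = np.linalg.norm(z - z_ref)
    return z + beta / (alpha + r) * (z - z_ref)

def radial_abs_det(z, z_ref, alpha, beta):
    """|det grad_z f| = [1 + beta h]^(d-1) * [1 + beta h + beta h' r]."""
    d = z.shape[0]
    r = np.linalg.norm(z - z_ref)
    h = 1.0 / (alpha + r)
    h_prime = -1.0 / (alpha + r) ** 2          # dh/dr
    return np.abs((1.0 + beta * h) ** (d - 1)
                  * (1.0 + beta * h + beta * h_prime * r))

def numerical_jacobian(func, z, eps=1e-6):
    """Central finite-difference Jacobian of func at z."""
    d = z.shape[0]
    jac = np.zeros((d, d))
    for j in range(d):
        step = np.zeros(d)
        step[j] = eps
        jac[:, j] = (func(z + step) - func(z - step)) / (2.0 * eps)
    return jac

rng = np.random.default_rng(1)
d = 3
z, z_ref = rng.normal(size=d), rng.normal(size=d)
alpha, beta = 1.0, 0.5                          # illustrative values

closed_form = radial_abs_det(z, z_ref, alpha, beta)
jacobian = numerical_jacobian(lambda v: radial_flow(v, z_ref, alpha, beta), z)
brute_force = np.abs(np.linalg.det(jacobian))

print(closed_form, brute_force)
```

The structure of the formula reflects the Jacobian $(1 + \beta h) I + \beta h' \gamma^{-1} (z - z_0)(z - z_0)^T$: it has the eigenvalue $1 + \beta h$ with multiplicity $d - 1$ and one eigenvalue $1 + \beta h + \beta h' \gamma$ along the radial direction.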