Variational Inference 变分推理
1.AE与VAE
Figure 1: Auto-Encoder.
如Fig.1,作为经典的网络结构之一,Auto-Encoder在深度学习的多个领域中有着出色的表现,依赖于重构误差的反向优化,模型可以学习到数据在低维空间的表示,过滤掉数据的冗余特征,得到细粒度特征。但是,AE学习的是确定性编码(潜在表示经常是确定性的),而不是数据的概率分布,不能直接提供关于数据分布的信息,可能会产生过拟合等问题。
Figure 2.
如Fig.2,为了提升模型的泛化能力,使模型具有理解数据分布和生成新样本的能力,Variational Auto-Encoder横空出世。与AE结构不同的是,VAE的潜在表示 z z z 的生成方式由分布采样得到。通过编码器得到潜在分布的均值 μ \mu μ 与标准差 σ \sigma σ 。在此基础上,加入服从正态分布的噪声 ϵ ∼ N ( 0 , 1 ) \epsilon\sim\mathcal{N}(0,1) ϵ∼N(0,1) ,使得 q ( z ∣ x ) = N ( μ , e x p ( σ ) 2 ) q(z|x)=\mathcal{N}(\mu,{exp(\sigma)}^2) q(z∣x)=N(μ,exp(σ)2) ,同时令 z = μ + e x p ( σ ) × ϵ z=\mu+exp(\sigma)\times\epsilon z=μ+exp(σ)×ϵ 。其中 e x p ( σ ) > 0 exp(\sigma)>0 exp(σ)>0 相当于噪声强度因子,且噪声的添加使得模型更具有抗扰动能力。
在loss function的约束项上,也要最小化 ∑ ( e x p ( σ ) − ( 1 + σ ) + μ 2 ) \sum(exp(\sigma)-(1+\sigma)+\mu^2) ∑(exp(σ)−(1+σ)+μ2) ,令 σ \sigma σ 趋近于0, μ \mu μ 趋近于0,故 e x p ( σ ) exp(\sigma) exp(σ) 趋近于1。由此,使得 q ( z ∣ x ) q(z|x) q(z∣x) 趋近于标准正态分布。注:尽管不同数据训练出的变分自编码器(VAE)可能会使潜在分布接近于标准正态分布,但由于模型内部参数的不同和数据的特征差异,即使手动将标准正态分布生成的潜在变量 z ′ z' z′ 输入到不同数据训练出的解码器中,不同解码器所生成的数据之间也会有很大的差异。
2.变分推理的数学推导
虽然我们通过神经网络得到了 q ( z ∣ x ) = N ( μ , e x p ( σ ) 2 ) q(z|x)=\mathcal{N}(\mu,{exp(\sigma)}^2) q(z∣x)=N(μ,exp(σ)2) ,但是 q ( z ∣ x ) q(z|x) q(z∣x) 就是由 x x x 得到的 z z z 的真实分布吗?定义由 x x x 得到的 z z z 的真实分布为 p ( z ∣ x ) p(z|x) p(z∣x) ,我们通过神经网络学习得到 q ( z ∣ x ) q(z|x) q(z∣x) 来近似 p ( z ∣ x ) p(z|x) p(z∣x) ,确保神经网络生成分布的准确性。我们假设先验分布 p ( z ) p(z) p(z) 与 后验分布 p ( z ∣ x ) p(z|x) p(z∣x) 为正态分布(因为计算KL等公式的便利性、标准正态分布被认为是一种无信息性先验、标准正态分布在潜在空间的均匀性等等),因此,为了拉近 q ( z ∣ x ) q(z|x) q(z∣x) 与 p ( z ∣ x ) p(z|x) p(z∣x) 我们使用KL散度(Kullback-Leibler Divergence)度量两个概率分布之间的差异程度(离散型、连续型):
离散型:
D
K
L
(
P
∥
Q
)
=
∑
i
=
1
n
P
i
l
o
g
(
P
i
Q
i
)
(1)
D_{KL}(P\parallel Q)=\sum_{i=1}^n P_ilog(\frac{P_i}{Q_i}) \tag{1}
DKL(P∥Q)=i=1∑nPilog(QiPi)(1)
其中
P
,
Q
P,Q
P,Q 为离散型随机变量的概率分布律
连续型:
D
K
L
(
P
∥
Q
)
=
∫
−
∞
+
∞
p
(
x
)
l
o
g
p
(
x
)
q
(
x
)
d
x
(2)
D_{KL}\left(P\parallel Q\right)=\int_{-\infty}^{+\infty}p(x)log\frac{p(x)}{q(x)}dx \tag{2}
DKL(P∥Q)=∫−∞+∞p(x)logq(x)p(x)dx(2)
其中
P
,
Q
P,Q
P,Q 为连续型随机变量的概率密度
因此,我们用KL散度衡量
q
(
z
∣
x
)
q(z|x)
q(z∣x) 与
p
(
z
∣
x
)
p(z|x)
p(z∣x) ,最小化下式:
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
∣
x
)
)
=
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
p
(
z
∣
x
)
d
z
(3)
KL(q(z\mid x)||p(z\mid x))=\int q(z\mid x)\log\frac{q(z\mid x)}{p(z\mid x)}dz \tag{3}
KL(q(z∣x)∣∣p(z∣x))=∫q(z∣x)logp(z∣x)q(z∣x)dz(3)
根据贝叶斯定理:
p
(
z
∣
x
)
=
p
(
x
∣
z
)
p
(
z
)
p
(
x
)
(4)
p(z\mid x)=\frac{p(x\mid z)p(z)}{p(x)} \tag{4}
p(z∣x)=p(x)p(x∣z)p(z)(4)
原式等价于:
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
∣
x
)
)
=
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
p
(
x
∣
z
)
p
(
z
)
p
(
x
)
d
z
=
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
d
z
+
∫
q
(
z
∣
x
)
log
p
(
x
)
d
z
−
∫
q
(
z
∣
x
)
log
[
p
(
x
∣
z
)
p
(
z
)
]
d
z
=
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
d
z
+
log
p
(
x
)
∫
q
(
z
∣
x
)
d
z
−
∫
q
(
z
∣
x
)
log
[
p
(
x
∣
z
)
p
(
z
)
]
d
z
(
其中
:
∫
q
(
z
∣
x
)
d
z
=
1
)
=
log
p
(
x
)
+
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
d
z
−
∫
q
(
z
∣
x
)
log
[
p
(
x
∣
z
)
p
(
z
)
]
d
z
=
log
p
(
x
)
+
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
d
z
−
∫
q
(
z
∣
x
)
log
[
p
(
x
∣
z
)
p
(
z
)
]
d
z
=
log
p
(
x
)
+
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
d
z
−
∫
q
(
z
∣
x
)
log
p
(
x
∣
z
)
d
z
−
∫
q
(
z
∣
x
)
log
p
(
z
)
d
z
=
log
p
(
x
)
+
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
p
(
z
)
d
z
−
∫
q
(
z
∣
x
)
log
p
(
x
∣
z
)
d
z
=
log
p
(
x
)
−
E
z
∼
q
(
z
∣
x
)
[
log
p
(
x
∣
z
)
]
+
D
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
\begin{align*} KL(q(z\mid x)||p(z\mid x))&=\int q(z\mid x)\log\frac{q(z\mid x)}{\frac{p(x\mid z)p(z)}{p(x)}}dz \\ &=\int q(z\mid x)\log q(z\mid x)dz+\int q(z\mid x)\log p(x)dz-\int q(z\mid x)\log[p(x\mid z)p(z)]dz \\ &=\int q(z\mid x)\log q(z\mid x)dz+\log p(x)\int q(z\mid x)dz-\int q(z\mid x)\log[p(x\mid z)p(z)]dz \ (其中:\int q(z\mid x)dz=1) \\ &=\log p(x)+\int q(z\mid x)\log q(z\mid x)dz-\int q(z\mid x)\log[p(x\mid z)p(z)]dz \\ &=\log p(x)+\int q(z\mid x)\log q(z\mid x)dz-\int q(z\mid x)\log[p(x\mid z)p(z)]dz \\ &=\log p(x)+\int q(z\mid x)\log q(z\mid x)dz-\int q(z\mid x)\log p(x\mid z)dz-\int q(z\mid x)\log p(z)dz \\ &=\log p(x)+\int q(z\mid x)\log\frac{q(z\mid x)}{p(z)}dz-\int q(z\mid x) \log p(x\mid z)dz \\ &=\log p(x)-E_{z\sim q(z\mid x)}\left[\log p(x\mid z)\right]+D_{KL}\left(q(z\mid x)||p(z)\right) \tag{5} \end{align*}
KL(q(z∣x)∣∣p(z∣x))=∫q(z∣x)logp(x)p(x∣z)p(z)q(z∣x)dz=∫q(z∣x)logq(z∣x)dz+∫q(z∣x)logp(x)dz−∫q(z∣x)log[p(x∣z)p(z)]dz=∫q(z∣x)logq(z∣x)dz+logp(x)∫q(z∣x)dz−∫q(z∣x)log[p(x∣z)p(z)]dz (其中:∫q(z∣x)dz=1)=logp(x)+∫q(z∣x)logq(z∣x)dz−∫q(z∣x)log[p(x∣z)p(z)]dz=logp(x)+∫q(z∣x)logq(z∣x)dz−∫q(z∣x)log[p(x∣z)p(z)]dz=logp(x)+∫q(z∣x)logq(z∣x)dz−∫q(z∣x)logp(x∣z)dz−∫q(z∣x)logp(z)dz=logp(x)+∫q(z∣x)logp(z)q(z∣x)dz−∫q(z∣x)logp(x∣z)dz=logp(x)−Ez∼q(z∣x)[logp(x∣z)]+DKL(q(z∣x)∣∣p(z))(5)
由于z是从分布中进行采样得到的,而采样过程是不可导的,而我们需要梯度的反传优化,为了将 Eq. (5) 中
E
z
∼
q
(
z
∣
x
)
[
log
p
(
x
∣
z
)
]
=
∫
q
(
z
∣
x
)
log
p
(
x
∣
z
)
d
z
E_{z\sim q(z\mid x)}\left[\log p(x\mid z)\right]=\int q(z\mid x) \log p(x\mid z)dz
Ez∼q(z∣x)[logp(x∣z)]=∫q(z∣x)logp(x∣z)dz 中的
z
z
z 消掉,我们使用重参数化技巧(例:给定
Z
∼
N
(
μ
,
σ
2
)
Z\sim\mathcal{N}(\mu,\sigma^2)
Z∼N(μ,σ2) ,
ϵ
∼
N
(
0
,
I
)
\epsilon\sim\mathcal{N}(0,\mathbf{I})
ϵ∼N(0,I) 故将
Z
Z
Z 转化为
Z
=
μ
+
σ
ϵ
Z=\mu+\sigma\epsilon
Z=μ+σϵ)将对
z
z
z 的采样等价于对其分布的均值,标准差的采样。
根据 Eq. (5),我们进行拆分:
D
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
=
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
p
(
z
)
d
z
=
∫
q
(
z
∣
x
)
log
q
(
z
∣
x
)
d
z
−
∫
q
(
z
∣
x
)
log
p
(
z
)
d
z
=
∫
N
(
z
;
μ
,
σ
2
)
log
N
(
z
;
μ
,
σ
2
)
d
z
−
∫
N
(
z
;
μ
,
σ
2
)
log
N
(
z
;
0
,
I
)
d
z
=
−
J
2
log
(
2
π
)
−
1
2
∑
j
=
1
J
(
1
+
log
σ
j
2
)
−
(
−
J
2
log
(
2
π
)
−
1
2
∑
j
=
1
J
(
μ
j
2
+
σ
j
2
)
)
=
1
2
∑
j
=
1
J
(
(
μ
j
)
2
+
(
σ
j
)
2
−
1
−
log
(
(
σ
j
)
2
)
)
\begin{align*} D_{KL}\left(q(z\mid x)||p(z)\right)&=\int q(z\mid x)\log\frac{q(z\mid x)}{p(z)}dz \\ &=\int q(z\mid x)\log q(z\mid x)dz-\int q(z\mid x)\log p(z)dz \\ &=\int\mathcal{N}\left(\mathbf{z};\boldsymbol{\mu},\boldsymbol{\sigma}^2\right)\log\mathcal{N}\left(\mathbf{z};\boldsymbol{\mu},\boldsymbol{\sigma}^2\right)d\mathbf{z}-\int\mathcal{N}\left(\mathbf{z};\boldsymbol{\mu},\boldsymbol{\sigma}^2\right)\log\mathcal{N}(\mathbf{z};\boldsymbol{0},\mathbf{I})d\mathbf{z} \\ &=-\frac J2\log(2\pi)-\frac12\sum_{j=1}^J\left(1+\log\sigma_j^2\right)-(-\frac J2\log(2\pi)-\frac12\sum_{j=1}^J\left(\mu_j^2+\sigma_j^2\right)) \\ &=\frac12\sum_{j=1}^J\left(\left(\mu_j\right)^2+\left(\sigma_j\right)^2-1-\log\left(\left(\sigma_j\right)^2\right)\right) \tag{6} \end{align*}
DKL(q(z∣x)∣∣p(z))=∫q(z∣x)logp(z)q(z∣x)dz=∫q(z∣x)logq(z∣x)dz−∫q(z∣x)logp(z)dz=∫N(z;μ,σ2)logN(z;μ,σ2)dz−∫N(z;μ,σ2)logN(z;0,I)dz=−2Jlog(2π)−21j=1∑J(1+logσj2)−(−2Jlog(2π)−21j=1∑J(μj2+σj2))=21j=1∑J((μj)2+(σj)2−1−log((σj)2))(6)
其中
log
p
(
x
)
−
E
z
∼
q
(
z
∣
x
)
[
log
p
(
x
∣
z
)
]
\log p(x)-E_{z\sim q(z\mid x)}\left[\log p(x\mid z)\right]
logp(x)−Ez∼q(z∣x)[logp(x∣z)] 等价于MSE(或其他Loss,代表真实值与预测值的损失):
log
p
(
x
)
−
E
z
∼
q
(
z
∣
x
)
[
log
p
(
x
∣
z
)
]
→
M
S
E
=
1
n
∑
i
−
1
n
(
x
i
−
y
i
)
2
\begin{align*} \log p(x)-E_{z\sim q(z\mid x)}\left[\log p(x\mid z)\right]\to MSE=\frac1n\sum_{i-1}^{n}{(x_i-y_i)^2} \tag{7} \end{align*}
logp(x)−Ez∼q(z∣x)[logp(x∣z)]→MSE=n1i−1∑n(xi−yi)2(7)
因此
L
L
L 如下:
L
=
1
n
∑
i
−
1
n
(
x
i
−
y
i
)
2
+
1
2
∑
j
=
1
J
(
(
μ
j
)
2
+
(
σ
j
)
2
−
1
−
log
(
(
σ
j
)
2
)
)
(8)
L=\frac1n\sum_{i-1}^{n}{(x_i-y_i)^2}+\frac12\sum_{j=1}^J\left(\left(\mu_j\right)^2+\left(\sigma_j\right)^2-1-\log\left(\left(\sigma_j\right)^2\right)\right) \tag{8}
L=n1i−1∑n(xi−yi)2+21j=1∑J((μj)2+(σj)2−1−log((σj)2))(8)
综上所述,我们通过最小化
L
L
L 使神经网络学习得到的
q
(
z
∣
x
)
q(z|x)
q(z∣x) 来近似真实分布
p
(
z
∣
x
)
p(z|x)
p(z∣x) ,使自编码器具有泛化性和生成新样本的能力。