Kingma D P, Welling M. Auto-Encoding Variational Bayes[J]. arXiv: Machine Learning, 2013.
主要内容
自编码, 通过引入Encoder和Decoder来估计联合分布 p ( x , z ) p(x,z) p(x,z), 其中 z z z表示隐变量(我们也可以让 z z z为样本标签, 使得Encoder成为一个判别器).
在Decoder中我们建立联合分布
p
θ
(
x
,
z
)
p_{\theta}(x,z)
pθ(x,z)以估计
p
(
x
,
z
)
p(x,z)
p(x,z), 在Encoder中建立一个后验分布
q
ϕ
(
z
∣
x
)
q_{\phi}(z|x)
qϕ(z∣x)去估计
p
θ
(
z
∣
x
)
p_{\theta}(z|x)
pθ(z∣x), 然后极大似然:
log
p
θ
(
x
)
=
log
p
θ
(
x
,
z
)
p
θ
(
z
∣
x
)
=
log
p
θ
(
x
,
z
)
q
ϕ
(
z
∣
x
)
q
ϕ
(
z
∣
x
)
p
θ
(
z
∣
x
)
=
log
p
θ
(
x
,
z
)
q
ϕ
(
z
∣
x
)
+
log
q
ϕ
(
z
∣
x
)
p
θ
(
z
∣
x
)
,
\begin{array}{ll} \log p_{\theta}(x) &= \log \frac{p_{\theta}(x,z)}{p_{\theta}(z|x)} \\ & = \log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)} \\ & = \log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} + \log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)} \\ \end{array},
logpθ(x)=logpθ(z∣x)pθ(x,z)=logqϕ(z∣x)pθ(x,z)pθ(z∣x)qϕ(z∣x)=logqϕ(z∣x)pθ(x,z)+logpθ(z∣x)qϕ(z∣x),
上式俩边关于
z
z
z在分布
q
ϕ
(
z
)
q_{\phi}(z)
qϕ(z)下求期望可得:
log
p
θ
(
x
)
=
E
q
ϕ
(
z
∣
x
)
(
log
p
θ
(
x
,
z
)
q
ϕ
(
z
∣
x
)
+
log
q
ϕ
(
z
∣
x
)
p
θ
(
z
∣
x
)
)
=
E
q
ϕ
(
z
∣
x
)
(
log
p
θ
(
x
,
z
)
q
ϕ
(
z
∣
x
)
)
+
D
K
L
(
q
ϕ
(
z
∣
x
)
∥
p
θ
(
z
∣
x
)
)
≥
E
q
ϕ
(
z
∣
x
)
(
log
p
θ
(
x
,
z
)
q
ϕ
(
z
∣
x
)
)
.
\begin{array}{ll} \log p_{\theta}(x) & = \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} + \log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)}) \\ &= \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} )+D_{KL}(q_{\phi}(z|x)\| p_{\theta}(z |x ))\\ & \ge \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} ) \end{array}.
logpθ(x)=Eqϕ(z∣x)(logqϕ(z∣x)pθ(x,z)+logpθ(z∣x)qϕ(z∣x))=Eqϕ(z∣x)(logqϕ(z∣x)pθ(x,z))+DKL(qϕ(z∣x)∥pθ(z∣x))≥Eqϕ(z∣x)(logqϕ(z∣x)pθ(x,z)).
既然KL散度非负, 我们极大似然 log p θ ( x ) \log p_{\theta}(x) logpθ(x)可以退而求其次, 最大化 E q ϕ ( z ∣ x ) ( log p θ ( x , z ) q ϕ ( z ∣ x ) ) \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} ) Eqϕ(z∣x)(logqϕ(z∣x)pθ(x,z))(ELBO, 记为 L \mathcal{L} L).
又(
p
θ
(
z
)
p_{\theta}(z)
pθ(z)为人为给定的先验分布)
L
(
θ
,
ϕ
;
x
)
=
−
D
K
L
(
q
ϕ
(
z
∣
x
)
∥
p
θ
(
z
)
)
+
E
q
ϕ
(
z
∣
x
)
[
log
p
θ
(
x
∣
z
)
]
,
\begin{array}{ll} \mathcal{L}(\theta, \phi; x) &= -D_{KL}(q_{\phi}(z|x)\|p_{\theta}(z))+\mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)], \end{array}
L(θ,ϕ;x)=−DKL(qϕ(z∣x)∥pθ(z))+Eqϕ(z∣x)[logpθ(x∣z)],
我们接下来通过对Encoder和Decoder的一些构造进一步扩展上面俩项.
Encoder (损失part1)
Encoder 将 x → z x\rightarrow z x→z, 就相当于在 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x)中进行采样, 但是如果是直接采样的话, 就没法利用梯度回传进行训练了, 这里需要一个重参化技巧.
我们假设
q
ϕ
(
z
∣
x
)
q_{\phi}(z|x)
qϕ(z∣x)为高斯密度函数, 即
N
(
μ
,
σ
2
I
)
\mathcal{N}(\mu, \sigma^2 I)
N(μ,σ2I).
注: 文中还提到了其他的一些可行假设.
我们构建一个神经网络
f
f
f, 其输入为样本
x
x
x, 输出为
(
μ
,
log
σ
)
(\mu, \log \sigma)
(μ,logσ)(输出
log
σ
\log \sigma
logσ是为了保证
σ
\sigma
σ为正), 则
z
=
μ
+
ϵ
⊙
σ
,
ϵ
∼
N
(
0
,
I
)
,
z= \mu + \epsilon \odot \sigma, \epsilon \sim \mathcal{N}(0, I),
z=μ+ϵ⊙σ,ϵ∼N(0,I),
其中
⊙
\odot
⊙表示按元素相乘.
注: 我们可以该输出为
(
μ
,
L
)
(\mu, L)
(μ,L)(
L
L
L为三角矩阵, 且对角线元素非负), 而假设
q
ϕ
(
z
∣
x
)
q_{\phi}(z|x)
qϕ(z∣x)的分量不独立, 其协方差函数为
L
T
L
L^TL
LTL, 则
(
z
=
μ
+
L
ϵ
(z=\mu + L \epsilon
(z=μ+Lϵ).
当
p
θ
(
z
)
=
N
(
0
,
I
)
p_{\theta}(z)=\mathcal{N}(0, I)
pθ(z)=N(0,I), 我们可以显示表达出:
Decoder (损失part2)
现在我们需要处理的是第二项, 文中这地方因为直接设计 p θ ( x , z ) p_{\theta}(x,z) pθ(x,z)不容易, 在我看来存粹是做不到的, 但是又用普通的分布代替不符合常理, 所以首先设计一个网络 g θ ( z ) g_{\theta}(z) gθ(z), 其输出为 x ^ \hat{x} x^, 然后假设 p ( x ∣ x ^ ) p(x|\hat{x}) p(x∣x^)的分布, 第二项就改为近似 E q ϕ ( z ∣ x ) p θ ( x ∣ x ^ ) \mathbb{E}_{q_{\phi}(z|x)}p_{\theta}(x|\hat{x}) Eqϕ(z∣x)pθ(x∣x^).
这么做的好处是显而易见的, 因为Decoder部分, 我们可以通过给定一个 z z z然后获得一个 x ^ \hat{x} x^, 这是很有用的东西, 但是我认为这种不是很合理, 因为除非 g g g是可逆的, 那么 p θ ( x ∣ z ) = p θ ( x ∣ x ^ ) p_{\theta}(x|z)= p _{\theta}(x|\hat{x}) pθ(x∣z)=pθ(x∣x^) (当然, 别无选择).
伯努利分布
此时
x
^
=
g
(
z
)
\hat{x}=g(z)
x^=g(z)是
x
=
1
x=1
x=1的概率, 则此时第二项的损失为
log
p
(
x
∣
x
^
)
=
∑
i
=
1
x
i
log
x
^
i
+
(
1
−
x
i
)
log
(
1
−
x
^
i
)
,
\log p(\mathbf{x}| \hat{\mathbf{x}})= \sum_{i=1} x_i \log \hat{x}_i + (1-x_i) \log (1- \hat{x}_i),
logp(x∣x^)=i=1∑xilogx^i+(1−xi)log(1−x^i),
为(二分类)交叉熵损失.
高斯分布
一种简单粗暴的, p ( x ∣ x ^ ) = N ( x ^ , σ 2 I ) p(x|\hat{x})=\mathcal{N}(\hat{x},\sigma^2 I) p(x∣x^)=N(x^,σ2I), 此时损失为类平方损失, 文中也有别的变换.
代码
import torch
import torch.nn as nn
class Loss(nn.Module):
def __init__(self, part2):
super(Loss, self).__init__()
self.part2 = part2
def forward(self, mu, sigma, real, fake, lam=1):
part1 = (1 + torch.log(sigma ** 2)
- mu ** 2 - sigma ** 2).sum() / 2
part2 = self.part2(fake, real)
return part1 + lam * part2