1 学习链接
参考链接:
[1] 【变分自编码器 VAE 鲁鹏】 https://www.bilibili.com/video/BV1Zq4y1h7Tu?share_source=copy_web&vd_source=7771b17ae75bc5131361e81a50a0c871
[2] http://www.gwylab.com/note-vae.html
[3] https://www.bilibili.com/video/BV15E411w7Pz/?spm_id_from=333.788.recommend_more_video.-1
2 问题记录
2.1 编码器损失函数问题
VAE中编码器的损失函数如下图黄色框:
对该损失的推导过程可参考:http://www.gwylab.com/note-vae.html
从该损失可以知道,当编码器的输出均值
m
i
m_i
mi为0,标准差
σ
i
\sigma_i
σi为1时,损失最小,同时可满足推理公式中,
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
=
0
KL(q(z|x)||P(z))=0
KL(q(z∣x)∣∣P(z))=0。
假设网络最终能训练的很好,不管输入什么x,编码器的输出都是正态分布,那么当解码器按照最大的
p
(
z
∣
x
)
p(z|x)
p(z∣x)采样z时,就都会得到
z
=
0
z=0
z=0,那解码器的输出就都一样。即不管输入是什么,输出都一样,网络就没有意义了。
对于这个问题,我之前在看VAE的推导过程时,忽略了一个点(当然教程也没讲,也可能只有我忽略了),那就是
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
=
0
KL(q(z|x)||P(z))=0
KL(q(z∣x)∣∣P(z))=0是对于一个x而言的,x就是不同的网络输入,在手写数字生成中,不同的mnist图像就是不同的x。
网络训练的最终目的是,使解码器对所有x的输出概率最高,即
∑
x
p
(
x
)
\sum_x p(x)
∑xp(x)最大,所以在一个batch中,实际编码器损失应该是,
arg min
θ
∑
x
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
\argmin_\theta \sum_x KL(q(z|x)||P(z))
argminθ∑xKL(q(z∣x)∣∣P(z)),而不同的
q
(
z
∣
x
)
q(z|x)
q(z∣x)是不可能同时满足
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
=
0
KL(q(z|x)||P(z))=0
KL(q(z∣x)∣∣P(z))=0的,即只要训练网络的x是多个,网络就永远也不可能收敛到对于不同的x都输出正态分布。
所以,基于编码器输入不同x输出的高斯分布,解码器按照最大的 p ( z ∣ x ) p(z|x) p(z∣x)采样z时, z z z是不相同的,这样解码器的输出就不相同。