概率角度理解VAE结构
1. 从联合概率分布构造的损失函数开始的一切
一个优秀的生成网络应该是怎么样的?这个生成网络在训练过程中,对编码器的要求应该是能够将输入 x x x编码为一对一的隐变量 z z z而不应该是多个 x x x对应着同一个 z z z。另外,在应用场景下的生成过程中,输入了一个处于训练集中隐变量中间位置的新的隐变量 z ′ z^\prime z′给解码器,其生成的输出 x x x应该满足某种在输入训练集中自动学习到的规律(这一点是传统自编码器的弊病,它被许多人诟病的地方在于遇到没有见过的隐变量,网络无法生成一个合乎规律的输出。这样的网络只是一个巨型压缩器,把所有输入的训练数据记忆了起来存储在网络权重中罢了)。比如一堆猫的图片喂给网络吃,吐出来的图片不应该是一张狗的图片或者是五脚猫的图片。
按照这种需求,这个网络的任务应该是逼近输入 x x x和隐变量 z z z的联合概率分布 p ( x , z ) p(x, z) p(x,z),由于 p ( x , z ) = p ( x ∣ z ) p ( z ) = p ( z ∣ x ) p ( x ) p(x, z) = p(x|z)p(z) = p(z|x)p(x) p(x,z)=p(x∣z)p(z)=p(z∣x)p(x)。 如此一来,已知 p ( x ) p(x) p(x),就能够得到编码器 p ( z ∣ x ) p(z|x) p(z∣x); 已知 p ( z ) p(z) p(z),就能够得到解码器 p ( x ∣ z ) p(x|z) p(x∣z)。万物皆概率分布,得概率者采样之后得天下。
因此网络的误差是下面公式的散度,真实的未知的联合概率分布与训练好的网络的联合概率分布之间的散度。下文的推导和理解主要参考了下面的博客,整理并归纳了推导过程:
苏剑林. (Mar. 28, 2018). 《变分自编码器(二):从贝叶斯观点出发 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/5343
1.1 定义
- x x x: 输入
- z z z: 隐变量
- p ( x ) p(x) p(x): 数据 x x x的真实分布(已知的训练数据)
- p ( x , z ) p(x, z) p(x,z): 数据 x x x和隐变量 z z z的真实联合概率分布(确定但未知)
- q ( x , z ) q(x, z) q(x,z): 网络估计的联合概率分布
- p ( x , z ) = p ( x ∣ z ) p ( z ) = p ( z ∣ x ) p ( x ) p(x, z) = p(x|z)p(z) = p(z|x)p(x) p(x,z)=p(x∣z)p(z)=p(z∣x)p(x)
- q ( x , z ) = q ( x ∣ z ) q ( z ) = q ( z ∣ x ) q ( z ) q(x, z) = q(x|z)q(z) = q(z|x)q(z) q(x,z)=q(x∣z)q(z)=q(z∣x)q(z)
1.2 推导过程
D
K
L
(
p
(
x
,
z
)
∣
∣
q
(
x
,
z
)
)
=
∫
∫
p
(
x
,
z
)
l
o
g
p
(
x
,
z
)
q
(
x
,
z
)
d
z
d
x
=
∫
∫
p
(
x
)
[
p
(
z
∣
x
)
l
o
g
p
(
x
)
p
(
z
∣
x
)
q
(
x
,
z
)
d
z
]
d
x
=
E
x
∼
p
(
x
)
[
∫
p
(
z
∣
x
)
l
o
g
p
(
x
)
p
(
z
∣
x
)
q
(
x
,
z
)
d
z
]
=
E
x
∼
p
(
x
)
[
∫
p
(
z
∣
x
)
l
o
g
(
p
(
x
)
)
d
z
]
+
E
x
∼
p
(
x
)
[
∫
p
(
z
∣
x
)
l
o
g
p
(
z
∣
x
)
q
(
x
,
z
)
d
z
]
=
∫
l
o
g
(
p
(
x
)
)
[
∫
p
(
z
∣
x
)
p
(
x
)
d
z
]
d
x
+
E
x
∼
p
(
x
)
[
∫
p
(
z
∣
x
)
l
o
g
p
(
z
∣
x
)
q
(
x
,
z
)
d
z
]
=
∫
p
(
x
)
l
o
g
(
p
(
x
)
)
d
x
+
E
x
∼
p
(
x
)
[
∫
p
(
z
∣
x
)
l
o
g
p
(
z
∣
x
)
q
(
x
,
z
)
d
z
]
=
E
x
∼
p
(
x
)
[
l
o
g
(
p
(
x
)
)
]
+
E
x
∼
p
(
x
)
[
∫
p
(
z
∣
x
)
l
o
g
p
(
z
∣
x
)
q
(
x
,
z
)
d
z
]
\begin{equation} \begin{aligned} & \mathbb{D}_{KL}\left(p(x, z)\big|\big|q(x, z)\right) \\ =&\int\int p(x, z)log\frac{p(x, z)}{q(x, z)}dz dx \\ =&\int\int p(x)\left[p(z|x)log\frac{p(x)p(z|x)}{q(x, z)}dz\right]dx \\ =&\mathbb{E}_{x\sim p(x)}\left[\int p(z|x)log\frac{p(x)p(z|x)}{q(x, z)}dz\right] \\ =&\mathbb{E}_{x\sim p(x)}\left[\int p(z|x)log(p(x))dz\right] + \mathbb{E}_{x\sim p(x)}\left[\int p(z|x)log\frac{p(z|x)}{q(x, z)}dz\right] \\ =& \int log(p(x))\bigg[\int p(z|x)p(x)dz\bigg] dx + \mathbb{E}_{x\sim p(x)}\left[\int p(z|x)log\frac{p(z|x)}{q(x, z)}dz\right] \\ =& \int p(x)log(p(x)) dx + \mathbb{E}_{x\sim p(x)}\left[\int p(z|x)log\frac{p(z|x)}{q(x, z)}dz\right] \\ =& \mathbb{E}_{x\sim p(x)}\bigg[ log(p(x))\bigg] + \mathbb{E}_{x\sim p(x)}\left[\int p(z|x)log\frac{p(z|x)}{q(x, z)}dz\right] \\ \end{aligned} \end{equation}
=======DKL(p(x,z)∣
∣∣
∣q(x,z))∫∫p(x,z)logq(x,z)p(x,z)dzdx∫∫p(x)[p(z∣x)logq(x,z)p(x)p(z∣x)dz]dxEx∼p(x)[∫p(z∣x)logq(x,z)p(x)p(z∣x)dz]Ex∼p(x)[∫p(z∣x)log(p(x))dz]+Ex∼p(x)[∫p(z∣x)logq(x,z)p(z∣x)dz]∫log(p(x))[∫p(z∣x)p(x)dz]dx+Ex∼p(x)[∫p(z∣x)logq(x,z)p(z∣x)dz]∫p(x)log(p(x))dx+Ex∼p(x)[∫p(z∣x)logq(x,z)p(z∣x)dz]Ex∼p(x)[log(p(x))]+Ex∼p(x)[∫p(z∣x)logq(x,z)p(z∣x)dz]
右侧第二项可以继续展开:
E
x
∼
p
(
x
)
[
∫
p
(
z
∣
x
)
l
o
g
p
(
z
∣
x
)
q
(
x
,
z
)
d
z
]
=
E
x
∼
p
(
x
)
[
∫
p
(
z
∣
x
)
l
o
g
p
(
z
∣
x
)
q
(
x
∣
z
)
q
(
z
)
d
z
]
=
E
x
∼
p
(
x
)
[
−
∫
p
(
z
∣
x
)
l
o
g
(
q
(
x
∣
z
)
)
d
z
]
+
E
x
∼
p
(
x
)
[
∫
p
(
z
∣
x
)
l
o
g
p
(
z
∣
x
)
q
(
z
)
d
z
]
=
E
x
∼
p
(
x
)
[
E
z
∼
p
(
z
∣
x
)
[
−
l
o
g
(
q
(
x
∣
z
)
)
]
+
D
K
L
(
p
(
z
∣
x
)
∣
∣
q
(
z
)
)
]
\begin{equation} \begin{aligned} & \mathbb{E}_{x\sim p(x)}\left[\int p(z|x)log\frac{p(z|x)}{q(x, z)}dz\right] \\ =& \mathbb{E}_{x\sim p(x)}\left[\int p(z|x)log\frac{p(z|x)}{q(x|z)q(z)}dz\right] \\ =& \mathbb{E}_{x\sim p(x)}\left[-\int p(z|x)log\big(q(x|z)\big)dz\right] + \mathbb{E}_{x\sim p(x)}\left[\int p(z|x)log\frac{p(z|x)}{q(z)}dz\right] \\ =& \mathbb{E}_{x\sim p(x)}\left[\mathbb{E}_{z\sim p(z|x)}\bigg[-log\big( q(x|z) \big)\bigg] + \mathbb{D}_{KL}\bigg(p(z|x)\bigg|\bigg|q(z)\bigg)\right] \end{aligned} \end{equation}
===Ex∼p(x)[∫p(z∣x)logq(x,z)p(z∣x)dz]Ex∼p(x)[∫p(z∣x)logq(x∣z)q(z)p(z∣x)dz]Ex∼p(x)[−∫p(z∣x)log(q(x∣z))dz]+Ex∼p(x)[∫p(z∣x)logq(z)p(z∣x)dz]Ex∼p(x)[Ez∼p(z∣x)[−log(q(x∣z))]+DKL(p(z∣x)∣
∣∣
∣q(z))]
因此总的表达式可以简化为:
D
K
L
(
p
(
x
,
z
)
∣
∣
q
(
x
,
z
)
)
=
E
x
∼
p
(
x
)
[
E
z
∼
p
(
z
∣
x
)
[
−
l
o
g
(
q
(
x
∣
z
)
)
]
+
D
K
L
(
p
(
z
∣
x
)
∣
∣
q
(
z
)
)
]
+
c
o
n
s
t
\begin{equation} \mathbb{D}_{KL}\left(p(x, z)\big|\big|q(x, z)\right) = \mathbb{E}_{x\sim p(x)}\left[\mathbb{E}_{z\sim p(z|x)}\bigg[-log\big( q(x|z) \big)\bigg] + \mathbb{D}_{KL}\bigg(p(z|x)\bigg|\bigg|q(z)\bigg)\right] + const \end{equation}
DKL(p(x,z)∣
∣∣
∣q(x,z))=Ex∼p(x)[Ez∼p(z∣x)[−log(q(x∣z))]+DKL(p(z∣x)∣
∣∣
∣q(z))]+const
其中,
c
o
n
s
t
=
E
x
∼
p
(
x
)
[
l
o
g
(
p
(
x
)
)
]
const = \mathbb{E}_{x\sim p(x)}\bigg[ log(p(x))\bigg]
const=Ex∼p(x)[log(p(x))]是常数,因为真实分布是不变的,且其期望是确定的。
1.3 损失函数的理解
- D K L ( p ( z ∣ x ) ∣ ∣ q ( z ) ) \mathbb{D}_{KL}\bigg(p(z|x)\bigg|\bigg|q(z)\bigg) DKL(p(z∣x)∣ ∣∣ ∣q(z))散度趋于零,表示希望两者分布能够趋于一致,如此一来,隐变量 z z z将与输入 x x x无关,因为只有两个随机变量无关时, p ( z ∣ x ) = q ( z ) p(z|x)=q(z) p(z∣x)=q(z)。用人话翻译一下就是,不希望输入 x x x和隐变量 z z z相关,从而提高模型的泛化能力。
- E z ∼ p ( z ∣ x ) [ − l o g ( q ( x ∣ z ) ) ] \mathbb{E}_{z\sim p(z|x)}\bigg[-log\big( q(x|z) \big)\bigg] Ez∼p(z∣x)[−log(q(x∣z))]趋于零表示,由 p ( x ∣ z ) p(x|z) p(x∣z)采样出来的隐变量 z ′ z^\prime z′所对应的条件概率 q ( x ∣ z ′ ) q(x|z^\prime) q(x∣z′)能够趋向于1。用人话翻译一下就是,希望编码器 p ( z ∣ x ) p(z|x) p(z∣x)得到的隐变量 z ′ z^\prime z′,输入到解码器 q ( x ∣ z ′ ) q(x|z^\prime) q(x∣z′)后能够以无限趋近于1的概率生成的 x x x。
这两者显然是相互矛盾的,第一点的泛化能力提高,会导致第二点模型还原能力下降, 反之亦然。这种天然的拮抗关系,是VAE模型最大的优势,因为从模型上就已经将两者平衡起来了,不需要设计者自己调整权重。
2. 总结
至此,我们并没有涉及到任何的网络结构,但从损失函数中,不难看出,为了组成一个能够满足我们在文章开头所畅想的优秀的生成网络所应该具备的基本条件,这个网络架构所需要用到的部件有 1). q ( z ) q(z) q(z), 2). p ( z ∣ x ) p(z|x) p(z∣x) 以及最后的 3). p ( x ∣ z ) p(x|z) p(x∣z)。下一篇博客将开始考虑实现VAE的具体细节。
祝周四分之七愉快!
2022年8月25日
Dianye Huang