文章目录
前言
我将看过的论文建了一个github库,方便各位阅读地址
传统的VAE,隐变量服从标准高斯分布(单峰),但有时候,单个高斯分布可能不能完全表达图像x的特征,比如MINIST数据集有0~9这10个数字,直觉上使用10个高斯分布来替代单个高斯分布更为合理,因此有学者将混合高斯分布模型(GMM)与VAE进行结合,其结果便是GMVAE。
FBI warning
本文为代码与论文结合进行理解的产物,如有错误,欢迎指出。本文不会进行ELBO的推导,将直接从论文给出的ELBO算式进行讲解。
GMVAE的生成过程
生成步骤如下:
说人话就是:
- 1a表示从标准正态分布中进行采样,得到 w w w,具体的采样方法我会写一篇博客进行说明
- 1b表示从Mult分布中采样 z = [ z 1 , z 2 , . . . z K ] z=[z_1,z_2,...z_K] z=[z1,z2,...zK], z z z其实是一个one-hot编码,其实可以自己随意指定
- 由于 z k z_k zk的取值非0即1,而 A 0 = 1 A^0=1 A0=1,所以1c表示从GMM中选择一个高斯分布进行采样,得到隐变量x,GMM中每个高斯分布的均值和方差将由步骤一采样到的 w w w进行变化得到,K为高斯分布个数
- 1d表示利用隐变量x生成图像y,由Decoder完成
GMVAE的损失函数
损失函数由变分推断推导而来,由于论文遗漏了太多推导细节,本文将不会介绍这部分推导,将重点介绍损失函数的各个部分如何计算。
与VAE一样,GMVAE通过最大化ELBO来进行优化,ELBO的形式如下:
L E L B O = E q ( x ∣ y ) [ p θ ( y ∣ x ) ] − E q ( w ∣ y ) p ( z ∣ x , w ) [ K L ( q ϕ x ( x ∣ y ) ∣ ∣ p β ( x ∣ w , z ) ) ] − K L ( q ϕ x ( w ∣ y ) ∣ ∣ p ( w ) ) − E q ( x ∣ y ) q ( w ∣ y ) [ K L ( p β ( z ∣ x , w ) ∣ ∣ p ( z ) ) ] \begin{aligned} L_{ELBO}=&E_{q(x|y)}[p_\theta(y|x)]-E_{q(w|y)p(z|x,w)}[KL(q_{\phi_x}(x|y)||p_{\beta}(x|w,z))]\\ &-KL(q_{\phi_x}(w|y)||p(w))-E_{q_(x|y)q(w|y)}[KL(p_\beta(z|x,w)||p(z))] \end{aligned} LELBO=Eq(x∣y)[pθ(y∣x)]−Eq(w∣y)p(z∣x,w)[KL(qϕx(x∣y)∣∣pβ(x∣w,z))]−KL(qϕx(w∣y)∣∣p(w))−Eq(x∣y)q(w∣y)[KL(pβ(z∣x,w)∣∣p(z))]
ϕ x 、 θ 、 β \phi_x、\theta、\beta ϕx、θ、β表示待优化的参数,可以暂时忽视。
- E q ( x ∣ y ) [ p θ ( y ∣ x ) ] E_{q(x|y)}[p_\theta(y|x)] Eq(x∣y)[pθ(y∣x)]称为reconstruction term
- E q ( w ∣ y ) p ( z ∣ x , w ) [ K L ( q ϕ x ( x ∣ y ) ∣ ∣ p β ( x ∣ w , z ) ) ] E_{q(w|y)p(z|x,w)}[KL(q_{\phi_x}(x|y)||p_{\beta}(x|w,z))] Eq(w∣y)p(z∣x,w)[KL(qϕx(x∣y)∣∣pβ(x∣w,z))]表示conditional prior term
- K L ( q ϕ x ( w ∣ y ) ∣ ∣ p ( w ) ) KL(q_{\phi_x}(w|y)||p(w)) KL(qϕx(w∣y)∣∣p(w))表示w-prior term
- E q ( x ∣ y ) q ( w ∣ y ) [ K L ( p β ( z ∣ x , w ) ∣ ∣ p ( z ) ) ] E_{q_(x|y)q(w|y)}[KL(p_\beta(z|x,w)||p(z))] Eq(x∣y)q(w∣y)[KL(pβ(z∣x,w)∣∣p(z))]表示z-prior term
接下来我将介绍每一部分的计算方式
reconstruction term
E q ( x ∣ y ) [ p θ ( y ∣ x ) ] E_{q(x|y)}[p_\theta(y|x)] Eq(x∣y)[pθ(y∣x)]表示重构误差,由于我们假定 p θ ( y ∣ x ) p_\theta(y|x) pθ(y∣x)服从高斯分布,所以与VAE一样,可以用均方误差进行计算。
conditional prior term
对 E q ( w ∣ y ) p ( z ∣ x , w ) [ K L ( q ϕ x ( x ∣ y ) ∣ ∣ p β ( x ∣ w , z ) ) ] E_{q(w|y)p(z|x,w)}[KL(q_{\phi_x}(x|y)||p_{\beta}(x|w,z))] Eq(w∣y)p(z∣x,w)[KL(qϕx(x∣y)∣∣pβ(x∣w,z))]使用蒙特卡洛模拟,可得
1 M ∑ j = 1 M ∑ k = 1 K p β ( z k = 1 ∣ x ( j ) , w ( j ) ) K L ( q ϕ x ( x ∣ y ) ∣ ∣ p β ( x ∣ w ( j ) , z k = 1 ) ) (1.0) \frac{1}{M}\sum_{j=1}^M\sum_{k=1}^Kp_{\beta}(z_k=1|x^{(j)},w^{(j)})KL(q_{\phi_x}(x|y)||p_\beta(x|w^{(j)},z_k=1))\tag{1.0} M1j=1∑Mk=1∑Kpβ(zk=1∣x(j),w(j))KL(qϕx(x∣y)∣∣pβ(x∣w(j),zk=1))(1.0)
M M M采样的样本数,我们可以将其设置为1,则1.0可变化为
∑ k = 1 K p β ( z k = 1 ∣ x , w ) K L ( q ϕ x ( x ∣ y ) ∣ ∣ p β ( x ∣ w , z k = 1 ) ) = ∑ k = 1 K p β ( z k = 1 ∣ x , w ) E q ϕ x ( x ∣ y ) [ log q ϕ x ( x ∣ y ) p β ( x ∣ w , z k = 1 ) ] = ∑ k = 1 K p β ( z k = 1 ∣ x , w ) 1 N ∑ i = 1 N log q ϕ x ( x i ∣ y ) p β ( x i ∣ w , z k = 1 ) (2.0) \begin{aligned} &\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)KL(q_{\phi_x}(x|y)||p_\beta(x|w,z_k=1))\\ =&\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)E_{q_{\phi_x}(x|y)}[\log\frac{q_{\phi_x}(x|y)}{p_\beta(x|w,z_k=1)}]\\ =&\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\frac{1}{N}\sum_{i=1}^N\log\frac{q_{\phi_x}(x_i|y)}{p_\beta(x_i|w,z_k=1)} \end{aligned}\tag{2.0} ==k=1∑Kpβ(zk=1∣x,w)KL(qϕx(x∣y)∣∣pβ(x∣w,zk=1))k=1∑Kpβ(zk=1∣x,w)Eqϕx(x∣y)[logpβ(x∣w,zk=1)qϕx(x∣y)]k=1∑Kpβ(zk=1∣x,w)N1i=1∑Nlogpβ(xi∣w,zk=1)qϕx(xi∣y)(2.0)
第三行式子利用蒙特卡洛模拟得到,同理,将N设置为1,式2.0可变为
∑
k
=
1
K
p
β
(
z
k
=
1
∣
x
,
w
)
log
q
ϕ
x
(
x
∣
y
)
p
β
(
x
∣
w
,
z
k
=
1
)
=
∑
k
=
1
K
p
β
(
z
k
=
1
∣
x
,
w
)
log
q
ϕ
x
(
x
∣
y
)
−
∑
k
=
1
K
p
β
(
z
k
=
1
∣
x
,
w
)
log
p
β
(
x
∣
w
,
z
k
=
1
)
=
log
q
ϕ
x
(
x
∣
y
)
−
∑
k
=
1
K
p
β
(
z
k
=
1
∣
x
,
w
)
log
p
β
(
x
∣
w
,
z
k
=
1
)
(3.0)
\begin{aligned} &\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\log\frac{q_{\phi_x}(x|y)}{p_\beta(x|w,z_k=1)}\\ =&\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\log q_{\phi_x}(x|y)-\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\log p_\beta(x|w,z_k=1)\\ =&\log q_{\phi_x}(x|y)-\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\log p_\beta(x|w,z_k=1)\tag{3.0} \end{aligned}
==k=1∑Kpβ(zk=1∣x,w)logpβ(x∣w,zk=1)qϕx(x∣y)k=1∑Kpβ(zk=1∣x,w)logqϕx(x∣y)−k=1∑Kpβ(zk=1∣x,w)logpβ(x∣w,zk=1)logqϕx(x∣y)−k=1∑Kpβ(zk=1∣x,w)logpβ(x∣w,zk=1)(3.0)
K K K为混合高斯分布中高斯分布的个数,我们有如下假设:
- q ϕ x ( x ∣ y ) q_{\phi_x}(x|y) qϕx(x∣y)是一个多元高斯分布,其期望与方差为 μ ϕ x \mu^{\phi_x} μϕx、 ( δ ϕ x ) 2 (\delta^{\phi_x})^2 (δϕx)2。为方便书写与做图,本文将用一元高斯分布形式进行推导,请读者自行将推导结果中的期望与方差替换为多元高斯分布形式。
- 依据式1c, log p β ( x ∣ w , z k = 1 ) \log p_\beta(x|w,z_k=1) logpβ(x∣w,zk=1)是均值为 μ k β \mu^\beta_k μkβ,方差为 ( δ k β ) 2 (\delta^{\beta}_k)^2 (δkβ)2的多元高斯分布
则有
log
p
β
(
x
∣
w
,
z
k
=
1
)
=
log
1
2
π
δ
k
β
e
−
(
x
−
μ
k
β
)
2
2
(
δ
k
β
)
2
=
log
1
2
π
−
log
δ
k
β
−
(
x
−
μ
k
β
)
2
2
(
δ
k
β
)
2
(4.0)
\begin{aligned} \log p_\beta(x|w,z_k=1)&=\log\frac{1}{\sqrt {2\pi}\delta^\beta_k}e^{-\frac{(x-\mu^\beta_k)^2}{2(\delta^\beta_k)^2}}\\ &=\log \frac{1}{\sqrt{2\pi}}-\log \delta_k^\beta-\frac{(x-\mu^\beta_k)^2}{2(\delta^\beta_k)^2}\tag{4.0} \end{aligned}
logpβ(x∣w,zk=1)=log2πδkβ1e−2(δkβ)2(x−μkβ)2=log2π1−logδkβ−2(δkβ)2(x−μkβ)2(4.0)
log q ϕ x ( x ∣ y ) = log 1 2 π δ ϕ x e − ( x − μ ϕ x ) 2 2 ( δ ϕ x ) 2 = log 1 2 π δ ϕ x e − ( x − μ ϕ x ) 2 2 ( δ ϕ x ) 2 = log 1 2 π − log δ ϕ x − ( x − μ ϕ x ) 2 2 ( δ ϕ x ) 2 (5.0) \begin{aligned} \log q_{\phi_x}(x|y)&=\log \frac{1}{\sqrt {2\pi}\delta^{\phi_x}}e^{-\frac{(x-\mu^{\phi_x})^2}{2(\delta^{\phi_x})^2}}\\ &=\log \frac{1}{\sqrt {2\pi}\delta^{\phi_x}}e^{-\frac{(x-\mu^{\phi_x})^2}{2(\delta^{\phi_x})^2}}\\ &=\log \frac{1}{\sqrt{2\pi}}-\log \delta^{\phi_x}-\frac{(x-\mu^{\phi_x})^2}{2(\delta^{\phi_x})^2} \end{aligned}\tag{5.0} logqϕx(x∣y)=log2πδϕx1e−2(δϕx)2(x−μϕx)2=log2πδϕx1e−2(δϕx)2(x−μϕx)2=log2π1−logδϕx−2(δϕx)2(x−μϕx)2(5.0)
x x x是服从 q ϕ x ( x ∣ y ) q_{\phi_x}(x|y) qϕx(x∣y)分布的样本,可以通过VAE提出的reparameterization trick得到
w-prior term
K L ( q ϕ x ( w ∣ y ) ∣ ∣ p ( w ) ) KL(q_{\phi_x}(w|y)||p(w)) KL(qϕx(w∣y)∣∣p(w))有如下假设
- q ϕ x ( w ∣ y ) q_{\phi_x}(w|y) qϕx(w∣y)服从期望为 [ μ 1 ϕ w 、 μ 2 ϕ w . . . . . . μ n ϕ w ] [\mu_1^{\phi_w}、\mu_2^{\phi_w}......\mu_n^{\phi_w}] [μ1ϕw、μ2ϕw......μnϕw],方差为 [ ( δ 1 ϕ w ) 2 、 ( δ 2 ϕ w ) 2 . . . . . . ( δ n ϕ w ) 2 ] [(\delta_1^{\phi_w})^2、(\delta_2^{\phi_w})^2......(\delta_n^{\phi_w})^2] [(δ1ϕw)2、(δ2ϕw)2......(δnϕw)2]的独立多元高斯分布
- p ( w ) p(w) p(w)服从标准正态分布
则有
K
L
(
q
ϕ
x
(
w
∣
y
)
∣
∣
p
(
w
)
)
=
1
2
∑
i
=
1
n
(
(
μ
i
ϕ
w
)
2
+
(
δ
i
ϕ
w
)
2
−
1
−
log
(
δ
i
ϕ
w
)
2
)
(6.0)
\begin{aligned} KL(q_{\phi_x}(w|y)||p(w))=\frac{1}{2}\sum_{i=1}^n((\mu_i^{\phi_w})^2+(\delta_i^{\phi_w})^2-1-\log (\delta_i^{\phi_w})^2) \end{aligned}\tag{6.0}
KL(qϕx(w∣y)∣∣p(w))=21i=1∑n((μiϕw)2+(δiϕw)2−1−log(δiϕw)2)(6.0)
z-prior term
同理,对 E q ( x ∣ y ) q ( w ∣ y ) [ K L ( p β ( z ∣ x , w ) ∣ ∣ p ( z ) ) ] E_{q_(x|y)q(w|y)}[KL(p_\beta(z|x,w)||p(z))] Eq(x∣y)q(w∣y)[KL(pβ(z∣x,w)∣∣p(z))]使用蒙特卡洛模拟,可得
1 M ∑ i = 1 M K L ( p β ( z ∣ x i , w i ) ∣ ∣ p ( z ) ) \frac{1}{M}\sum_{i=1}^MKL(p_\beta(z|x_i,w_i)||p(z)) M1i=1∑MKL(pβ(z∣xi,wi)∣∣p(z))
我们有如下假设
- p(z)为均匀分布,设 p ( z ) = 1 K p(z)=\frac{1}{K} p(z)=K1, K K K为混合高斯分布中高斯分布的个数
将M设置为1,则有
K L ( p β ( z ∣ x , w ) ∣ ∣ p ( z ) ) = ∑ k = 1 K p β ( z k = 1 ∣ x , w ) log p β ( z k = 1 ∣ x , w ) p ( z k = 1 ) = ∑ k = 1 K p β ( z k = 1 ∣ x , w ) [ log p β ( z k = 1 ∣ x , w ) + log K ] \begin{aligned} KL(p_\beta(z|x,w)||p(z))&=\sum_{k=1}^Kp_\beta(z_k=1|x,w)\log \frac{p_\beta(z_k=1|x,w)}{p(z_k=1)}\\ &=\sum_{k=1}^Kp_\beta(z_k=1|x,w)[\log p_\beta(z_k=1|x,w)+\log K] \end{aligned} KL(pβ(z∣x,w)∣∣p(z))=k=1∑Kpβ(zk=1∣x,w)logp(zk=1)pβ(zk=1∣x,w)=k=1∑Kpβ(zk=1∣x,w)[logpβ(zk=1∣x,w)+logK]
GMVAE的结构
本节结构为博主阅读代码后所得,博主没有复现GMVAE,故仅供参考。
图像生成的结构如下
如果您想了解更多有关深度学习、机器学习基础知识,或是java开发、大数据相关的知识,欢迎关注我们的公众号,我将在公众号上不定期更新深度学习、机器学习相关的基础知识,分享深度学习中有趣文章的阅读笔记。