变分自编码器(VAE)
变分自编码器(VAE)是通过大量的观测数据 x i \mathbf{x}_{i} xi 总结出数据的分布 p ( X ) p(\mathbf{X}) p(X),进而可以通过无穷次的采样获取所有的数据 X \mathbf{X} X,包含观测到的 x i \mathbf{x}_{i} xi 以及未观测到的 x j \mathbf{x}_{j} xj,这是个生成式模型, x j \mathbf{x}_{j} xj 就是生成结果。
然而分布 p ( X ) p(\mathbf{X}) p(X) 是不存在解析解的,我们构造一个参数化的分布 p θ ( X ) p_{\theta}(\mathbf{X}) pθ(X) 来逼近 p ( X ) p(\mathbf{X}) p(X),而优化这些参数的方法便是最大似然估计,即 θ ∗ = argmin θ ( − log ( p θ ( X ) ) ) \theta^{*}=\operatorname{argmin}_{\theta}(-\log(p_{\theta}(\mathbf{X}))) θ∗=argminθ(−log(pθ(X)))。
VAE受启发于自编码器,自编码器先将数据 x i \mathbf{x}_{i} xi 映射到一个低维的隐空间 z i \mathbf{z}_{i} zi, 再利用该隐变量恢复 x i \mathbf{x}_{i} xi,其目的是可以用低维的隐变量 z i \mathbf{z}_{i} zi 来准确表示高维输入数据 x i \mathbf{x}_{i} xi,这样一来能够对 x i \mathbf{x}_{i} xi 进行有效的数据压缩,用最本质的数据 z i \mathbf{z}_{i} zi 来表示 x i \mathbf{x}_{i} xi,剔除了大量与 x i \mathbf{x}_{i} xi 的本质无关的表示数据。
自编码器需要施加输入和输出相等的约束,VAE虽然架构上和自编码器相同,但是VAE不需要输入和输出相等,其追求的反而是输入和输出不相同,但是输入和输出的本质属性是相同的,不同的是附加在本质属性上的其他表现形式,也就是输入观测数据 x i \mathbf{x}_{i} xi,VAE提取出其本质属性表示 z \mathbf{z} z(这里没有加下标 i i i 是因为多个 x i \mathbf{x}_{i} xi的本质表示可能是相同的),然后基于本质表示 z \mathbf{z} z 附加一些额外的表示生成新的未观测过的数据 x j \mathbf{x}_{j} xj, x j \mathbf{x}_{j} xj 和 x i \mathbf{x}_{i} xi 在本质上是相同的。
为了方便理解,可以举一个简单的例子, x i \mathbf{x}_{i} xi 可以是”灰色的猫“,那么其本质属性表示 z \mathbf{z} z 就是”猫“,包含猫特有的属性(比如体态,眼睛,嘴巴,胡须等),但是对于颜色的属性(灰色), z \mathbf{z} z 将其滤除了,因为灰色并不是猫的本质属性,接下来VAE可以基于本质属性表示 z \mathbf{z} z 附加一些额外的属性生成新的未观测数据 x j \mathbf{x}_{j} xj — ”橘色的猫“,甚至也可以是”灰色的折耳猫“,这样 x j \mathbf{x}_{j} xj 和 x i \mathbf{x}_{i} xi 是不相等的,但是其本质上是相同的。
基于隐变量 z \mathbf{z} z,来估计 X \mathbf{X} X 的分布的方法是 p θ ( X ) = ∫ p θ ( X ∣ z ) p θ ( z ) d z p_{\theta}(\mathbf{X})=\int p_{\theta}(\mathbf{X}\mid\mathbf{z})p_{\theta}(\mathbf{z}) \mathrm{d} \mathbf{z} pθ(X)=∫pθ(X∣z)pθ(z)dz,设定 p θ ( z ) ∼ N ( 0 , I ) p_{\theta}(\mathbf{z})\sim\mathcal{N}(\mathbf{0}, \mathbf{I}) pθ(z)∼N(0,I),由于 z \mathbf{z} z 是通过压缩 X \mathbf{X} X 得到的,所以用 p θ ( z ∣ X ) p_{\theta}(\mathbf{z}\mid\mathbf{X}) pθ(z∣X) 来代替 p θ ( z ) p_{\theta}(\mathbf{z}) pθ(z),VAE的初步框架如下图所示:
但是
p
θ
(
z
∣
X
)
p_{\theta}(\mathbf{z}\mid\mathbf{X})
pθ(z∣X) 是不好求的,因为:
p
θ
(
z
∣
x
i
)
=
p
θ
(
x
i
∣
z
)
p
(
z
)
p
θ
(
x
i
)
=
p
θ
(
x
i
∣
z
)
p
(
z
)
∫
z
^
p
θ
(
x
i
∣
z
^
)
p
(
z
^
)
d
z
^
(1)
\begin{aligned} p_{\theta}\left(\mathbf{z} \mid \mathbf{x}_{i}\right) &=\frac{p_{\theta}\left(\mathbf{x}_{i} \mid \mathbf{z}\right) p(\mathbf{z})}{p_{\theta}\left(\mathbf{x}_{i}\right)} \\ &=\frac{p_{\theta}\left(\mathbf{x}_{i} \mid \mathbf{z}\right) p(\mathbf{z})}{\int_{\hat{\mathbf{z}}} p_{\theta}\left(\mathbf{x}_{i} \mid \hat{\mathbf{z}}\right) p(\hat{\mathbf{z}}) d \hat{\mathbf{z}}} \end{aligned} \tag{1}
pθ(z∣xi)=pθ(xi)pθ(xi∣z)p(z)=∫z^pθ(xi∣z^)p(z^)dz^pθ(xi∣z)p(z)(1) 分子好求,但是分母需要对
z
\mathbf{z}
z 进行大量的采样,这是不可行的,这里就引入了变分推理的方法,这就是为什么VAE叫变分自编码器了,引入
q
ϕ
(
z
∣
X
)
q_{\phi}\left(\mathbf{z} \mid \mathbf{X}\right)
qϕ(z∣X) 来近似
p
θ
(
z
∣
X
)
p_{\theta}(\mathbf{z}\mid\mathbf{X})
pθ(z∣X),所以VAE的框架变成了下图:
接下来就是推导损失函数,还是基于最大似然估计,计算使得
log
p
θ
(
X
)
\log p_{\theta}(\mathbf{X})
logpθ(X) 最大的
θ
\theta
θ。
log
p
θ
(
X
)
=
1
⋅
log
p
θ
(
X
)
=
(
∫
z
q
ϕ
(
z
∣
X
)
d
z
)
⋅
log
p
θ
(
X
)
=
∫
z
q
ϕ
(
z
∣
X
)
log
p
θ
(
X
)
d
z
log
p
θ
(
X
)
与
z
无关
=
∫
z
q
ϕ
(
z
∣
X
)
log
p
θ
(
X
,
z
)
p
θ
(
z
∣
X
)
d
z
贝
叶
斯
定
理
=
∫
z
q
ϕ
˙
(
z
∣
X
)
log
(
p
θ
(
X
,
z
)
q
ϕ
(
z
∣
X
)
⋅
q
ϕ
(
z
∣
X
)
p
θ
(
z
∣
X
)
)
d
z
=
∫
z
q
ϕ
(
z
∣
X
)
log
p
θ
(
X
,
z
)
q
ϕ
(
z
∣
X
)
d
z
+
∫
z
q
ϕ
(
z
∣
X
)
log
q
ϕ
(
z
∣
X
)
p
θ
(
z
∣
X
)
d
z
=
ℓ
(
p
θ
,
q
ϕ
)
+
D
K
L
(
q
ϕ
,
p
θ
)
≥
ℓ
(
p
θ
,
q
ϕ
)
K
L
散
度
非
负
.
(2)
\begin{aligned} \log p_{\theta}(\mathbf{X}) &=1 \cdot \log p_{\theta}(\mathbf{X}) \\ &=\left(\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}) d \mathbf{z}\right) \cdot \log p_{\theta}(\mathbf{X}) \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}) \log p_{\theta}(\mathbf{X}) d \mathbf{z} \quad \log p_{\theta}(\mathbf{X}) 与 \mathbf{z} \text { 无关 } \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}) \log \frac{p_{\theta}(\mathbf{X}, \mathbf{z})}{p_{\theta}(\mathbf{z} \mid \mathbf{X})} d \mathbf{z} ~~~~~~ 贝叶斯定理 \\ &=\int_{\mathbf{z}} q_{\dot{\phi}}(z \mid \mathbf{X}) \log \left(\frac{p_{\theta}(\mathbf{X}, z)}{q_{\phi}(z \mid \mathbf{X})} \cdot \frac{q_{\phi}(\mathbf{z} \mid \mathbf{X})}{p_{\theta}(\mathbf{z} \mid \mathbf{X})}\right) d \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}) \log \frac{p_{\theta}(\mathbf{X}, \mathbf{z})}{q_{\phi}(\mathbf{z} \mid \mathbf{X})} d \mathbf{z}+\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}) \log \frac{q_{\phi}(\mathbf{z} \mid \mathbf{X})}{p_{\theta}(\mathbf{z} \mid \mathbf{X})} d \mathbf{z} \\ &=\ell\left(p_{\theta}, q_{\phi}\right)+D_{K L}\left(q_{\phi}, p_{\theta}\right) \geq \ell\left(p_{\theta}, q_{\phi}\right) \quad K L 散度非负. \end{aligned} \tag{2}
logpθ(X)=1⋅logpθ(X)=(∫zqϕ(z∣X)dz)⋅logpθ(X)=∫zqϕ(z∣X)logpθ(X)dzlogpθ(X)与z 无关 =∫zqϕ(z∣X)logpθ(z∣X)pθ(X,z)dz 贝叶斯定理=∫zqϕ˙(z∣X)log(qϕ(z∣X)pθ(X,z)⋅pθ(z∣X)qϕ(z∣X))dz=∫zqϕ(z∣X)logqϕ(z∣X)pθ(X,z)dz+∫zqϕ(z∣X)logpθ(z∣X)qϕ(z∣X)dz=ℓ(pθ,qϕ)+DKL(qϕ,pθ)≥ℓ(pθ,qϕ)KL散度非负.(2) 将上式重新表示为:
ℓ
(
p
θ
,
q
ϕ
)
=
log
p
θ
(
X
)
−
D
K
L
(
q
ϕ
,
p
θ
)
(3)
\ell\left(p_{\theta}, q_{\phi}\right)=\log p_{\theta}(\mathbf{X})-D_{K L}\left(q_{\phi}, p_{\theta}\right) \tag{3}
ℓ(pθ,qϕ)=logpθ(X)−DKL(qϕ,pθ)(3) 则最大化
ℓ
(
p
θ
,
q
ϕ
)
\ell\left(p_{\theta}, q_{\phi}\right)
ℓ(pθ,qϕ) 就相当于最大化
log
p
θ
(
X
)
\log p_{\theta}(\mathbf{X})
logpθ(X) 和最小化
D
K
L
(
q
ϕ
,
p
θ
)
D_{K L}\left(q_{\phi}, p_{\theta}\right)
DKL(qϕ,pθ),所以优化目标也就变成了最大化
ℓ
(
p
θ
,
q
ϕ
)
\ell\left(p_{\theta}, q_{\phi}\right)
ℓ(pθ,qϕ),
ℓ
(
p
θ
,
q
ϕ
)
\ell\left(p_{\theta}, q_{\phi}\right)
ℓ(pθ,qϕ) 在变分推理中也叫 ELBO (Empirical Lower Bound)。进一步:
ℓ
(
p
θ
,
q
ϕ
)
=
∫
z
q
ϕ
(
z
∣
X
)
log
p
θ
(
X
,
z
)
q
ϕ
(
z
∣
X
)
d
z
=
∫
z
q
ϕ
(
z
∣
X
)
log
p
θ
(
X
∣
z
)
p
(
z
)
q
ϕ
(
z
∣
X
)
d
z
贝叶斯定理
=
∫
z
q
ϕ
(
z
∣
X
)
log
p
(
z
)
q
ϕ
(
z
∣
X
)
d
z
+
∫
z
q
ϕ
(
z
∣
X
)
log
p
θ
(
X
∣
z
)
d
z
=
−
D
K
L
(
q
ϕ
,
p
)
+
E
q
ϕ
[
log
p
θ
(
X
∣
z
)
]
(4)
\begin{aligned} \ell\left(p_{\theta}, q_{\phi}\right)&=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}) \log \frac{p_{\theta}(\mathbf{X}, \mathbf{z})}{q_{\phi}(\mathbf{z} \mid \mathbf{X})} d \mathbf{z}\\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}) \log \frac{p_{\theta}(\mathbf{X} \mid \mathbf{z}) p(\mathbf{z})}{q_{\phi}(\mathbf{z} \mid \mathbf{X})} d \mathbf{z} \quad \text { 贝叶斯定理 }\\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}) \log \frac{p(\mathbf{z})}{q_{\phi}(\mathbf{z} \mid \mathbf{X})} d \mathbf{z}+\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}) \log p_{\theta}(\mathbf{X} \mid \mathbf{z}) d \mathbf{z} \\ &=-D_{K L}\left(q_{\phi}, p\right)+\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}(\mathbf{X} \mid \mathbf{z})\right] \end{aligned} \tag{4}
ℓ(pθ,qϕ)=∫zqϕ(z∣X)logqϕ(z∣X)pθ(X,z)dz=∫zqϕ(z∣X)logqϕ(z∣X)pθ(X∣z)p(z)dz 贝叶斯定理 =∫zqϕ(z∣X)logqϕ(z∣X)p(z)dz+∫zqϕ(z∣X)logpθ(X∣z)dz=−DKL(qϕ,p)+Eqϕ[logpθ(X∣z)](4)
条件变分自编码器(CVAE)
在条件变分自编码器(CVAE)中,模型的输出就不是
x
j
\mathbf{x}_j
xj 了,而是对应于输入
x
i
\mathbf{x}_i
xi 的任务相关数据
y
i
\mathbf{y}_i
yi,例如分类任务就是长度为类别数的向量,所以损失函数得需重新推一遍,不过套路和VAE是一样的,这次的最大似然估计变成了
log
p
θ
(
Y
∣
X
)
\log p_{\theta}(\mathbf{Y}\mid\mathbf{X})
logpθ(Y∣X),即:
log
p
θ
(
Y
∣
X
)
=
1
⋅
log
p
θ
(
Y
∣
X
)
=
(
∫
z
q
ϕ
(
z
∣
X
,
Y
)
d
z
)
log
p
θ
(
Y
∣
X
)
=
∫
z
q
ϕ
(
z
∣
X
,
Y
)
log
p
θ
(
Y
∣
X
)
d
z
=
∫
z
q
ϕ
(
z
∣
X
,
Y
)
log
p
θ
(
z
,
X
,
Y
)
p
θ
(
z
∣
X
,
Y
)
p
θ
(
X
)
d
z
=
∫
z
q
ϕ
(
z
∣
X
,
Y
)
log
q
ϕ
(
z
∣
X
,
Y
)
p
θ
(
z
∣
X
,
Y
)
p
θ
(
z
,
X
,
Y
)
q
ϕ
(
z
∣
X
,
Y
)
p
θ
(
X
)
d
z
=
∫
z
q
ϕ
(
z
∣
X
,
Y
)
log
q
ϕ
(
z
∣
X
,
Y
)
p
θ
(
z
∣
X
,
Y
)
d
z
+
∫
z
q
ϕ
(
z
∣
X
,
Y
)
log
p
θ
(
z
,
X
,
Y
)
q
ϕ
(
z
∣
X
,
Y
)
p
θ
(
X
)
d
z
=
D
K
L
(
q
ϕ
,
p
θ
)
+
ℓ
(
p
θ
,
q
ϕ
)
(5)
\begin{aligned} \log p_{\theta}(\mathbf{Y}\mid\mathbf{X})&=1\cdot\log p_{\theta}(\mathbf{Y}\mid\mathbf{X})\\ &=\left(\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\mathrm{d}\mathbf{z}\right)\log p_{\theta}(\mathbf{Y}\mid\mathbf{X}) \\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log p_{\theta}(\mathbf{Y}\mid\mathbf{X})\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z}\mid\mathbf{X},\mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z}\mid\mathbf{X},\mathbf{Y})}\frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z}\mid\mathbf{X},\mathbf{Y})}\mathrm{d}\mathbf{z}~+~\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=D_{K L}(q_{\phi}, p_{\theta}) ~+~ \ell(p_{\theta}, q_{\phi}) \tag{5} \end{aligned}
logpθ(Y∣X)=1⋅logpθ(Y∣X)=(∫zqϕ(z∣X,Y)dz)logpθ(Y∣X)=∫zqϕ(z∣X,Y)logpθ(Y∣X)dz=∫zqϕ(z∣X,Y)logpθ(z∣X,Y)pθ(X)pθ(z,X,Y)dz=∫zqϕ(z∣X,Y)logpθ(z∣X,Y)qϕ(z∣X,Y)qϕ(z∣X,Y)pθ(X)pθ(z,X,Y)dz=∫zqϕ(z∣X,Y)logpθ(z∣X,Y)qϕ(z∣X,Y)dz + ∫zqϕ(z∣X,Y)logqϕ(z∣X,Y)pθ(X)pθ(z,X,Y)dz=DKL(qϕ,pθ) + ℓ(pθ,qϕ)(5) 则 ELBO 为
ℓ
(
p
θ
,
q
ϕ
)
\ell(p_{\theta}, q_{\phi})
ℓ(pθ,qϕ),进一步:
ℓ
(
p
θ
,
q
ϕ
)
=
∫
z
q
ϕ
(
z
∣
X
,
Y
)
log
p
θ
(
z
,
X
,
Y
)
q
ϕ
(
z
∣
X
,
Y
)
p
θ
(
X
)
d
z
=
∫
z
q
ϕ
(
z
∣
X
,
Y
)
log
p
θ
(
Y
∣
X
,
Z
)
p
θ
(
Z
∣
X
)
p
θ
(
X
)
q
ϕ
(
z
∣
X
,
Y
)
p
θ
(
X
)
d
z
=
∫
z
q
ϕ
(
z
∣
X
,
Y
)
log
p
θ
(
Z
∣
X
)
q
ϕ
(
z
∣
X
,
Y
)
d
z
+
∫
z
q
ϕ
(
z
∣
X
,
Y
)
log
p
θ
(
Y
∣
X
,
Z
)
d
z
=
−
D
K
L
(
q
ϕ
(
z
∣
X
,
Y
)
∣
p
θ
(
Z
∣
X
)
)
+
E
q
ϕ
[
log
p
θ
(
Y
∣
X
,
Z
)
]
(6)
\begin{aligned} \ell(p_{\theta}, q_{\phi})&=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{Y}\mid\mathbf{X},\mathbf{Z})p_{\theta}(\mathbf{Z}\mid\mathbf{X})p_{\theta}(\mathbf{X})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})p_{\theta}(\mathbf{X})}\mathrm{d}\mathbf{z}\\ &=\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log\frac{p_{\theta}(\mathbf{Z}\mid\mathbf{X})}{q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})}\mathrm{d}\mathbf{z}~+~\int_{\mathbf{z}}q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\log p_{\theta}(\mathbf{Y}\mid\mathbf{X,\mathbf{Z}})\mathrm{d}\mathbf{z}\\ &=-D_{K L}(q_{\phi}(\mathbf{z}\mid\mathbf{X}, \mathbf{Y})\mid p_{\theta}(\mathbf{Z}\mid\mathbf{X}))~+~\mathbb{E}_{q_{\phi}}[\log p_{\theta}(\mathbf{Y}\mid\mathbf{X},\mathbf{Z})] \tag{6} \end{aligned}
ℓ(pθ,qϕ)=∫zqϕ(z∣X,Y)logqϕ(z∣X,Y)pθ(X)pθ(z,X,Y)dz=∫zqϕ(z∣X,Y)logqϕ(z∣X,Y)pθ(X)pθ(Y∣X,Z)pθ(Z∣X)pθ(X)dz=∫zqϕ(z∣X,Y)logqϕ(z∣X,Y)pθ(Z∣X)dz + ∫zqϕ(z∣X,Y)logpθ(Y∣X,Z)dz=−DKL(qϕ(z∣X,Y)∣pθ(Z∣X)) + Eqϕ[logpθ(Y∣X,Z)](6)
网络结构包含三个部分:
- 先验网络 p θ ( z ∣ X ) p_{\theta}(\mathbf{z}\mid\mathbf{X}) pθ(z∣X),如下图(b)所示
- Recognition网络 q ϕ ( z ∣ X , Y ) q_{\phi}(\mathbf{z}\mid\mathbf{X},\mathbf{Y}) qϕ(z∣X,Y), 如下图(c)所示
- Decoder网络 p θ ( Y ∣ X , Z ) p_{\theta}(\mathbf{Y}\mid\mathbf{X},\mathbf{Z}) pθ(Y∣X,Z),如下图(b)所示。