变分自编码器(VAE)初识

AE回顾

Auto-Encoder,称自编码器,是一种无监督式学习模型。它基于反向传播算法与最优化方法(如梯度下降法),AE(Auto-Encoder)的架构可以如下所示;

在这里插入图片描述

X X X为整个数据集的集合, x i x_{i} xi是数据集中的一个样本。

自编码器包含一个编码器 z = g ( X ) z=g(X) z=g(X),称 z z z为编码,并且往往 z z z的维度远小于输入 X X X的维度。此外还包含一个解码器 X ~ = f ( z ) \tilde{X} =f(z) X~=f(z),这个解码器从编码 z z z中还原输入 X X X的信息。

我们自然希望 X ~ \tilde{X} X~ X X X尽可能地接近,因此我们可以定义自编码器的损失函数为 ℓ = ∥ X − X ~ ∥ \ell=\begin{Vmatrix} X-\tilde{X} \end{Vmatrix} = XX~ ,于是模型训练结束后,我们便可以认为 z z z蕴含了输入 X X X的大部分信息,则能够表达原始数据以实现数据降维的目的。

为此,我们可以将自编码器简记为:
X ∈ R C × H × W ⟶ z = g ( X ) ∈ R d ⟶ X ~ = f ( z ) ∈ R C × H × W X \in \mathbb{R}^{C\times H\times W}\longrightarrow z=g(X) \in \mathbb{R}^{d}\longrightarrow \tilde{X} =f(z) \in \mathbb{R}^{C\times H\times W} XRC×H×Wz=g(X)RdX~=f(z)RC×H×W其中 C C C表示通道数(彩色图片为RGB三通道), H , W H,W H,W分别表示图片数据的高和宽。

VAE

在AE那部分,可能有读者会想:能不能够直接从编码器 z = g ( X ) , z ∈ R d z=g(X),z \in \mathbb{R}^{d} z=g(X),zRd直接采样 z i z_{i} zi,再通过解码器 X ^ i = f ( z i ) \hat{X}_{i}=f(z_{i}) X^i=f(zi)生成图片呢?回答是很难,准确地来说一定概率下是可行的,但大概率会生成全是噪声的图片。这是因为 R d \mathbb{R}^{d} Rd是一个很大的空间,而符合条件的 z z z也许只占很小一分部,所以如果在整个 R d \mathbb{R}^{d} Rd空间上采样的话,自然难以生成符合预期的图片。也就是说由采样 z i z_{i} zi生成的 X ^ i = f ( z i ) \hat{X}_{i}=f(z_{i}) X^i=f(zi)可能会和 X i X_{i} Xi相距太远,但如果我们能够显式地对 z z z的分布 p ( z ) p(z) p(z)进行建模,从而只在 p ( z ) p(z) p(z)中采样 z z z,这便就是VAE(Variational Auto-Encoder),即变分自编码器。

VAE简述

不妨假设隐空间 z ∼ N ( 0 , I ) z\sim \mathcal N(0,I) zN(0,I),其中 I I I代表一个单位矩阵。记 X X X为随机变量, X i X_{i} Xi代表随机变量的样本。 z i z_{i} zi是从隐空间 z z z采样得到的样本(在VAE模型中并没有使用 p ( z ) p(z) p(z)(隐变量空间的分布)是标准正态分布的假设,实际上用的是假设 p ( z ∣ X ) p(z|X) p(zX)(后验分布)是服从正态分布的

具体来说,给定真实样本 x k x_{k} xk,假设存在一个服从多元正态分布的后验分布 p ( z ∣ x k ) p(z|x_{k}) p(zxk),且后验分布 p ( z ∣ x k ) p(z|x_{k}) p(zxk)对于不同的 k k k是相互独立的。事实上,在论文《Auto-Encoding Variational Bayes》中,也特别强调了这一点:

In this case, we can let the variational approximate posterior be a multivariate Gaussian with a diagonal covariance structure: log ⁡ q ϕ ( z ∣ x ( i ) ) = log ⁡ N ( z ; μ ( i ) , σ 2 ( i ) I ) ( 9 ) \log q_\phi(\boldsymbol{z}|\boldsymbol{x}^{(i)})=\log{\mathcal{N}(\boldsymbol{z};\boldsymbol{\mu}^{(i)},\boldsymbol{\sigma}^{2(i)}\boldsymbol{I})}\quad\quad\quad\quad(9) logqϕ(zx(i))=logN(z;μ(i),σ2(i)I)9

上式是实现VAE的关键,那么要如何找出后验分布 p ( z ∣ x i ) p(z|x_{i}) p(zxi)(已经假设是服从多元正态分布了)
的均值和方差呢?实际上我们可以通过从后验分布 p ( z ∣ x i ) p(z|x_{i}) p(zxi)采样的方式,然后利用神经网络拟合出来,设 μ i = f 1 ( X i ) μ_{i}=f_{1}(X_{i}) μi=f1(Xi) l o g σ i 2 = f 2 ( X i ) logσ^2_{i}=f_{2}(X_{i}) logσi2=f2(Xi),其中 μ i μ_{i} μi σ i σ_{i} σi分别表示后验分布 p ( z ∣ x i ) p(z|x_{i}) p(zxi)的均值和方差。这里要说明的是,选择拟合 l o g σ i 2 logσ^2_{i} logσi2是因为 σ i 2 σ^2_{i} σi2恒正,而 l o g σ i 2 logσ^2_{i} logσi2可正可负,这样就不用多添加一层激活函数层了。现在从已知的后验分布 p ( z ∣ x i ) p(z|x_{i}) p(zxi)中采样 z i z_{i} zi,通过解码器得到 x ^ k = f ( z k ) \hat{x}_{k}=f(z_{k}) x^k=f(zk),如果VAE训练得足够好(也就是 D ( x , x ^ k ) \mathcal{D}(x,\hat{x}_{k}) D(x,x^k)足够小, D \mathcal{D} D表示距离函数),我们可以认为 x ^ k \hat{x}_{k} x^k是与 x k x_{k} xk都来源于数据集 X X X的分布 p ( X ) p(X) p(X)

上述的过程我们表示为:设数据集是由某个随机过程生成的,而 z z z是随机过程中的一个不可观测到的隐空间。这个生成数据的随机过程包含两个步骤:

首先从先验分布 p ( z ) p(z) p(z)中采样得到一个 z i z_{i} zi,再从条件分布 p ( X ∣ z i ) p(X|z_{i}) p(Xzi)中采样得到一个数据点 x i x_{i} xi,如果能基于这个随机过程进行建模,那么便能够得到一个生成模型。

VAE中的Decoder

下面我们从生成模型的角度来看看VAE中的Decoder架构:
在这里插入图片描述
上图是VAE中的Decoder架构的示意图。现在给Decoder一个从 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1)采样得到的 z i z_{i} zi,我们希望由 θ \theta θ参数化的Decoder能够输出 z i z_{i} zi对应的 X X X的分布,即 f θ ( z i ) = p θ ( X ∣ z i ) f_{\theta}(z_{i})=p_{\theta}(X|z_{i}) fθ(zi)=pθ(Xzi)

假设给定任意 z i z_{i} zi后, X X X都服从某个各维度独立的多元高斯分布,即:
p θ ( X ∣ z i ) = N ( X ∣ μ i ′ ( z i ; θ ) , σ i ′ 2 ( z i ; θ ) ∗ I ) \begin{align}p_{\theta}(X|z_{i})=\mathcal N(X|\mu_{i}^{'}(z_{i};\theta), \sigma_{i}^{'2}(z_{i};\theta)*I)\end{align} pθ(Xzi)=N(Xμi(zi;θ),σi2(zi;θ)I)于是只要输入 z i z_{i} zi给Decoder,然后让它拟合出 μ i ′ \mu_{i}^{'} μi σ i ′ 2 \sigma_{i}^{'2} σi2,就能知道 X ∣ z i X|z_{i} Xzi的具体分布了。

如果所有的 p ( z ∣ x k ) p(z|x_{k}) p(zxk)都很接近标准正态分布 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1),那么根据定义
p ( Z ) = ∑ X p ( Z ∣ X ) p ( X ) = ∑ X N ( 0 , I ) p ( X ) = N ( 0 , I ) ∑ X p ( X ) = N ( 0 , I ) ( 2 ) p(Z)=∑Xp(Z|X)p(X)=∑XN(0,I)p(X)=N(0,I)∑Xp(X)=N(0,I)(2) p(Z)=Xp(ZX)p(X)=XN(0,I)p(X)=N(0,I)Xp(X)=N(0,I)(2)

这样我们就能达到我们的先验假设:p(Z)
是标准正态分布。然后我们就可以放心地从N(0,I)
中采样来生成图像了。

VAE的目标函数

下面我们从统计的视角来看看一个生成模型的目标函数应该是怎么样的。采用李宏毅老师的课件图,生成模型的目标可以如下图所示:(图中的 x x x z z z和本文VAE中的定义不完全一致)

在这里插入图片描述
生成模型是目标对数据集的分布 p ( X ) p(X) p(X)建模得到 θ \theta θ参数化的分布 p θ ( X ) p_{\theta}(X) pθ(X),从分布 p θ ( X ) p_{\theta}(X) pθ(X)中采样进而得到生成的数据。如上图所示, X X X代表一个图片的集合,得到 p θ ( X ) p_{\theta}(X) pθ(X)采样一个令 p θ ( x i ) p_{\theta}(x_{i}) pθ(xi)较大的 x i x_{i} xi,那么 x i x_{i} xi很可能就是和数据集合 X X X中的图片相似的一张新图片。

根据前面的内容,我们可以如下计算出 p θ ( X ) p_{\theta}(X) pθ(X):已知 p ( z ) ∼ N ( 0 , I ) , p θ ( X ∣ z i ) = N ( X ∣ μ i ′ ( z i ; θ ) , σ i ′ 2 ( z i ; θ ) ∗ I ) \begin{align}p(z)\sim \mathcal N(0,I),p_{\theta}(X|z_{i})=\mathcal N(X|\mu_{i}^{'}(z_{i};\theta), \sigma_{i}^{'2}(z_{i};\theta)*I)\end{align} p(z)N(0,I),pθ(Xzi)=N(Xμi(zi;θ),σi2(zi;θ)I)那么有: p θ ( X ) = ∫ z p θ ( X ∣ z ) p ( z ) d z ≈ 1 m ∑ j = 1 m p θ ( X ∣ z j ) \begin{align}p_{\theta}(X)=\int_{z}p_{\theta}(X|z)p(z)dz\approx \frac{1}{m} {\textstyle \sum_{j=1}^{m}} p_{\theta}(X|z_{j})\end{align} pθ(X)=zpθ(Xz)p(z)dzm1j=1mpθ(Xzj)利用极大似然估计(MLE),为了让数据集出现的概率最大化,也就是:
θ ∗ = a r g m i n θ − ∑ i = 1 n l o g ( p θ ( x i ) ) = a r g m i n θ − ∑ i = 1 n l o g ( ∑ j = 1 m p θ ( x i ∣ z j ) ) \begin{align}\theta^{*}&=argmin_{\theta }-\sum_{i=1}^{n} log(p_{\theta}(x_{i}))\notag\\ &=argmin_{\theta }-\sum_{i=1}^{n} log({\textstyle \sum_{j=1}^{m}} p_{\theta}(x_{i}|z_{j}))\end{align} θ=argminθi=1nlog(pθ(xi))=argminθi=1nlog(j=1mpθ(xizj))因为往往 x i x_{i} xi的维度会很大(比如RGB图片256*256), z j z_{j} zj的维度也不会很低( z z z是编码器得到的语义向量,比如一张图片可以抽取十个特征),并且对于某个 x i x_{i} xi而言,与之强相关的 z j z_{j} zj的数量是相对有限的,但是为了找到这些有限的 z j z_{j} zj,可能要进行大量的采样,而这在现实应用的意义下是几乎不可能的(这将会是一件computation costly的事情),所以我们希望能够以较小的代价采样得到更多强相关的 z j z_{j} zj,解决这个需求的办法便是在Encoder中引入后验分布 p θ ( z ∣ x i ) p_{\theta}(z|x_{i}) pθ(zxi)

VAE中的Encoder

如下图所示:

在这里插入图片描述
前一部分( X → Z X→Z XZ)可以看成是Encoder的架构,假设有后验分布 p θ ( z ∣ X ) p_{\theta}(z|X) pθ(zX),正向过程在计算的时候将 x i x_{i} xi传给Encoder,算出 z ∣ x i z|x_{i} zxi服从的分布后再从这个分布中采样出 z j z_{j} zj,再将采样得到的 z j z_{j} zj传给Decoder,然后便能够得到 X ∣ z i X|z_{i} Xzi的分布,然后再用极大似然估计使得 p ( X ∣ z i ) p(X|z_{i}) p(Xzi)最大化。如此采样得到的 z j z_{j} zj几乎是和 x i x_{i} xi相关的,这样便能保证采样的效率。

上述过程可以由贝叶斯公式计算得出: p θ ( z ∣ x i ) = p θ ( x i ∣ z ) p ( z ) p θ ( x i ) = p θ ( x i ∣ z ) p ( z ) ∫ z ′ p θ ( X ∣ z ′ ) p ( z ′ ) d z ′ \begin{align}p_{\theta}(z|x_{i})&=\frac{p_{\theta}(x_{i}|z) p(z)}{p_{\theta}(x_{i})}\notag\\&=\frac{p_{\theta}(x_{i}|z) p(z)}{\int_{z'}p_{\theta}(X|z')p(z')dz'}\end{align} pθ(zxi)=pθ(xi)pθ(xiz)p(z)=zpθ(Xz)p(z)dzpθ(xiz)p(z)我们依然可以通过采样大量的 z i z_{i} zi来计算上式分母的积分,但这同样是一个computation costly的工作,这时便能够应用变分贝叶斯:设由 ϕ \phi ϕ参数化的Encoder能够拟合对任意 x i x_{i} xi的分布 q ϕ ( z ∣ x i ) q_{\phi}(z|x_i{}) qϕ(zxi),如果这个分布 q ϕ ( z ∣ x i ) q_{\phi}(z|x_i{}) qϕ(zxi)能足够逼近真实的后验分布 p θ ( z ∣ x i ) p_{\theta}(z|x_{i}) pθ(zxi)的话,便能够直接由Encoder得到 z ∣ x i z|x_{i} zxi的分布。由于分布 q ϕ ( z ∣ x i ) q_{\phi}(z|x_i{}) qϕ(zxi)是多元高斯分布,因此只需要让Encoder输出参数 μ \mu μ ∑ 2 \sum^{2} 2即可。同时由前面的假设[公式(2)],由贝叶斯公式以及条件概率公式有: p θ ( z ∣ X ) = p ( z ) p θ ( X ∣ z ) p ( X ) \begin{align}p_{\theta}(z|X)=\frac{p(z)p_{\theta}(X|z)}{p(X)}\end{align} pθ(zX)=p(X)p(z)pθ(Xz)于是后验分布 p θ ( z ∣ X ) p_{\theta}(z|X) pθ(zX)也是多元高斯分布。假设近似后验分布对于任意的 x i x_{i} xi有:
q ϕ ( X ∣ z i ) = N ( X ∣ μ i ( z i ; ϕ ) , σ i 2 ( z i ; ϕ ) ∗ I ) \begin{align}q_{\phi}(X|z_{i})=\mathcal N(X|\mu_{i}^{}(z_{i};\phi), \sigma_{i}^{2}(z_{i};\phi)*I)\end{align} qϕ(Xzi)=N(Xμi(zi;ϕ),σi2(zi;ϕ)I),即近似后验分布 q ϕ ( z ∣ x i ) q_{\phi}(z|x_i{}) qϕ(zxi)也是一个各维度独立的多元正态分布。

VAE的架构

下图是VAE的架构图(基于MLP模型),其中 x i ( j ) x_{i}^{(j)} xi(j)表示第 i i i个数据点的第 j j j个特征。

在这里插入图片描述

由上图先给出VAE架构总结:
1.首先给Encoder输入一个数据点 x i x_{i} xi,通过神经网络,得到隐变量 z z z服从的近似后验分布 q ϕ ( z ∣ x i ) q_{\phi}(z|x_{i}) qϕ(zxi)的参数。Encoder输出服从的高斯分布的参数 σ i 2 \sigma_{i}^{2} σi2 μ i \mu_{i} μi
2.从高斯分布 q ϕ ( z ∣ x i ) q_{\phi}(z|x_{i}) qϕ(zxi)中采样 z j z_{j} zj,这个 z j z_{j} zj应当代表与 x i x_{i} xi相似的一类样本。
3.令Decoder拟合似然的分布 p θ ( X ∣ z j ) p_{\theta}(X|z_{j}) pθ(Xzj)。输入Decoder一个 z j z_{j} zj,它应当返回服从 X ∣ z j X|z_{j} Xzj的分布的参数。令Decoder输出服从的高斯分布的参数 σ i ′ 2 \sigma_{i}^{'2} σi2 μ i ′ \mu_{i}{'} μi即可。
4.在得到 X ∣ z j X|z_{j} Xzj的分布的参数后,从这个分布中进行采样,来生成可能的数据点x_{i}。

注意:大部分实现中,人们往往不进行采样,而是直接将模型输出的 μ i ′ \mu_{i}{'} μi当作是给定生成的数据点。除此之外,人们也往往认为是一个固定方差的各维度独立的多元高斯分布,即 σ i ′ 2 \sigma_{i}^{'2} σi2是一个人为给定的超参数。

可以看到,我们并没有用到 p ( Z ) p(Z) p(Z)是正态分布的假设。

参数重整化

为了让网络能够进行训练(能够反向传播和正向传播以计算梯度),利用参数重整化的技巧,替换步骤2中部分:从高斯分布 q ϕ ( z ∣ x i ) q_{\phi}(z|x_{i}) qϕ(zxi)中采样 z j z_{j} zj的服从 z ∼ N ( μ i , σ i 2 ) z\sim \mathcal N(\mu_{i}{},\sigma_{i}^{2}) zN(μiσi2),令 ϵ i ∼ N ( 0 , I ) \epsilon_{i}\sim \mathcal N(0,I) ϵiN(0,I), z = μ i + σ i ⊙ ϵ i z=\mu_{i}+\sigma_{i} \odot \epsilon_{i} z=μi+σiϵi,其中 ⊙ \odot 表示点积,则 z z z是服从 N ( μ i , σ i 2 ∗ I ) \mathcal N(\mu_{i}, \sigma_{i}^{2}*I) N(μi,σi2I)的各维度独立的多元高斯分布,并且z是可以求梯度的。那么VAE的架构可以如图表示:

在这里插入图片描述

VAE的损失函数

利用对数极大似然估计,最大化 l o g p θ ( X ) logp_{\theta}(X) logpθ(X)时,利用变分推断的思想,得到如下结果: log ⁡ p θ ( X ) = ∫ z q ϕ ( z ∣ X ) log ⁡ p θ ( X ) d z 全概率公式 = ∫ z q ϕ ( z ∣ X ) log ⁡ p θ ( X , z ) p θ ( z ∣ X ) d z 贝叶斯公式 = ∫ z q ϕ ( z ∣ X ) log ⁡ ( p θ ( X , z ) q ϕ ( z ∣ X ) ⋅ q ϕ ( z ∣ X ) p θ ( z ∣ X ) ) d z 恒等变换 = ∫ z q ϕ ( z ∣ X ) log ⁡ p θ ( X , z ) q ϕ ( z ∣ X ) d z + ∫ z q ϕ ( z ∣ X ) log ⁡ q ϕ ( z ∣ X ) p θ ( z ∣ X ) d z 拆开log = ℓ ( p θ , q ϕ ) + D K L ( q ϕ , p θ ) ≥ ℓ ( p θ , q ϕ ) K L 散度非负 . \begin{aligned} \log p_{\theta}(X)& =\int_zq_\phi(z\mid X)\log p_\theta(X)dz\quad\text{全概率公式} \\ &=\int_zq_\phi(z\mid X)\log\frac{p_\theta(X,z)}{p_\theta(z\mid X)}dz\quad\text{贝叶斯公式} \\ &=\int_zq_\phi(z\mid X)\log\biggl(\frac{p_\theta(X,z)}{q_\phi(z\mid X)}\cdot\frac{q_\phi(z\mid X)}{p_\theta(z\mid X)}\biggr)dz \quad\text{恒等变换}\\ &=\int_zq_\phi(z\mid X)\log\frac{p_\theta(X,z)}{q_\phi(z\mid X)}dz+\int_zq_\phi(z\mid X)\log\frac{q_\phi(z\mid X)}{p_\theta(z\mid X)}dz \quad\text{拆开log}\\ &=\ell\left(p_\theta,q_\phi\right)+\mathcal{D}_{KL}\left(q_\phi,p_\theta\right) \\ &\geq\ell\left(p_\theta,q_\phi\right)\quad KL\text{散度非负}. \end{aligned} logpθ(X)=zqϕ(zX)logpθ(X)dz全概率公式=zqϕ(zX)logpθ(zX)pθ(X,z)dz贝叶斯公式=zqϕ(zX)log(qϕ(zX)pθ(X,z)pθ(zX)qϕ(zX))dz恒等变换=zqϕ(zX)logqϕ(zX)pθ(X,z)dz+zqϕ(zX)logpθ(zX)qϕ(zX)dz拆开log=(pθ,qϕ)+DKL(qϕ,pθ)(pθ,qϕ)KL散度非负.

还记得我们在Pytorch中说到的关于交叉熵以及KL散度的一些概念吗?我们在那说明了KL散度的非负性。上式中 ℓ ( p θ , q ϕ ) \ell\left(p_\theta,q_\phi\right) (pθ,qϕ)显然是一个下界,这也是称 ℓ \ell 为ELBO(Evidence Lower Bound Objection)的原因。变换上述式子可以得到: ℓ ( p θ , q ϕ ) = l o g p θ ( X ) − D K L ( q ϕ , p θ ) \ell\left(p_\theta,q_\phi\right)=log p_{\theta}(X)-\mathcal{D}_{KL}\left(q_\phi,p_\theta\right) (pθ,qϕ)=logpθ(X)DKL(qϕ,pθ),于是最小化 ℓ ( p θ , q ϕ ) \ell\left(p_\theta,q_\phi\right) (pθ,qϕ)等价于最大化 l o g p θ ( X ) log p_{\theta}(X) logpθ(X)
,同时最小化 D K L ( q ϕ , p θ ) \mathcal{D}_{KL}\left(q_\phi,p_\theta\right) DKL(qϕ,pθ),也就是让 q ϕ , p θ q_\phi,p_\theta qϕ,pθ足够接近。其中 ℓ ( p θ , q ϕ ) \ell\left(p_\theta,q_\phi\right) (pθ,qϕ)代表近似后验分布 q ϕ q_\phi qϕ和真实后验分布 p θ p_\theta pθ之间的损失。

展开 ℓ ( p θ , q ϕ ) \ell\left(p_\theta,q_\phi\right) (pθ,qϕ)得到: ℓ ( p θ , q ϕ ) = ∫ z q ϕ ( z ∣ X ) log ⁡ p θ ( X , z ) q ϕ ( z ∣ X ) d z = ∫ z q ϕ ( z ∣ X ) log ⁡ p θ ( X ∣ z ) p ( z ) q ϕ ( z ∣ X ) d z 贝叶斯公式 = ∫ z q ϕ ( z ∣ X ) log ⁡ p ( z ) q ϕ ( z ∣ X ) d z + ∫ z q ϕ ( z ∣ X ) log ⁡ p θ ( X ∣ z ) d z = − D K L ( q ϕ , p ) + E q ϕ [ log ⁡ p θ ( X ∣ z ) ] . \begin{aligned} \ell\left(p_{\theta},q_{\phi}\right)& =\int_zq_\phi(z\mid X)\log\frac{p_\theta(X,z)}{q_\phi(z\mid X)}dz \\ &=\int_zq_\phi(z\mid X)\log\frac{p_\theta(X\mid z)p(z)}{q_\phi(z\mid X)}dz\quad\text{贝叶斯公式} \\ &=\int_zq_\phi(z\mid X)\log\frac{p(z)}{q_\phi(z\mid X)}dz+\int_zq_\phi(z\mid X)\log p_\theta(X\mid z)dz \\ &=-\mathcal{D}_{KL}\left(q_\phi,p\right)+\mathbb{E}_{q_\phi}\left[\log p_\theta(X\mid z)\right]. \end{aligned} (pθ,qϕ)=zqϕ(zX)logqϕ(zX)pθ(X,z)dz=zqϕ(zX)logqϕ(zX)pθ(Xz)p(z)dz贝叶斯公式=zqϕ(zX)logqϕ(zX)p(z)dz+zqϕ(zX)logpθ(Xz)dz=DKL(qϕ,p)+Eqϕ[logpθ(Xz)].,第一项 − D K L ( q ϕ , p ) -\mathcal{D}_{KL}\left(q_\phi,p\right) DKL(qϕ,p)为正则项,人们称为Latent loss,由于我们已经在前文假设了 q ϕ ( z ∣ X ) q_\phi(z|X) qϕ(zX) p ( z ) p(z) p(z)都服从高斯分布(公式(2)和公式(7)),且都是各维度独立的高斯分布,结合KL散度的性质可以得到其解析解,以一维情况为例: ( N ( μ , σ 2 ) ∥ N ( 0 , 1 ) ) = ∫ z 1 2 π σ 2 e x p ( − ( z − μ ) 2 2 σ 2 ) log ⁡ 1 2 π σ 2 e x p ( − ( z − μ ) 2 2 σ 2 ) 1 2 π e x p ( − z 2 2 ) d z = ∫ z ( − ( z − μ ) 2 2 σ 2 + z 2 2 − log ⁡ σ ) N ( μ , σ 2 ) d z = − ∫ z ( z − μ ) 2 2 σ 2 N ( μ , σ 2 ) d z + ∫ z z 2 2 N ( μ , σ 2 ) d z − ∫ z log ⁡ σ N ( μ , σ 2 ) d z \begin{aligned} \left(\mathcal{N}\left(\mu,\sigma^2\right)\|\mathcal{N}(0,1)\right)&=\int_z\frac{1}{\sqrt{2\pi\sigma^2}}\mathrm{exp}\left(-\frac{\left(z-\mu\right)^2}{2\sigma^2}\right)\log\frac{\frac{1}{\sqrt{2\pi\sigma^2}}\mathrm{exp}\left(-\frac{\left(z-\mu\right)^2}{2\sigma^2}\right)}{\frac{1}{\sqrt{2\pi}}\mathrm{exp}\left(-\frac{z^2}{2}\right)}dz \\ &=\int_z\left(\frac{-(z-\mu)^2}{2\sigma^2}+\frac{z^2}2-\log\sigma\right)\mathcal{N}\left(\mu,\sigma^2\right)dz \\ &=-\int_z\frac{\left(z-\mu\right)^2}{2\sigma^2}\mathcal{N}\left(\mu,\sigma^2\right)dz+\int_z\frac{z^2}2\mathcal{N}\left(\mu,\sigma^2\right)dz-\int_z\log\sigma\mathcal{N}\left(\mu,\sigma^2\right)dz \end{aligned} (N(μ,σ2)N(0,1))=z2πσ2 1exp(2σ2(zμ)2)log2π 1exp(2z2)2πσ2 1exp(2σ2(zμ)2)dz=z(2σ2(zμ)2+2z2logσ)N(μ,σ2)dz=z2σ2(zμ)2N(μ,σ2)dz+z2z2N(μ,σ2)dzzlogσN(μ,σ2)dz整个结果分为三项积分,第一项实际上就是 − l o g σ 2 −logσ^2 logσ2乘以概率密度的积分(也就是1),所以结果是 − l o g σ 2 −logσ^2 logσ2;第二项实际是正态分布的二阶矩,熟悉正态分布的读者应该都清楚正态分布的二阶矩为 μ 2 + σ 2 μ^2+σ^2 μ2+σ2;而根据定义,第三项实际上就是“-方差除以方差=-1”。所以总结果就是 1 2 ( − 1 + σ 2 + μ 2 − log ⁡ σ 2 ) \frac{1}{2}(-1+\sigma^{2}+\mu^{2}-\log\sigma^{2}) 21(1+σ2+μ2logσ2),推广到多元高斯分布可以得到: D K L ( q ϕ ( z ∣ X ) , p ( z ) ) = ∑ j = 1 d 1 2 ( − 1 + σ ( j ) 2 + μ ( j ) 2 − log ⁡ σ ( j ) 2 ) . D_{KL}\left(q_{\phi}(z\mid X),p(z)\right)=\sum_{j=1}^{d}\frac{1}{2}(-1+\sigma^{(j)}{}^{2}+\mu^{(j)}{}^{2}-\log\sigma^{(j)}{}^{2}). DKL(qϕ(zX),p(z))=j=1d21(1+σ(j)2+μ(j)2logσ(j)2).其中 σ ( j ) 2 \sigma^{(j)2} σ(j)2表示向量 a a a的第 j j j个元素的平方。

E q ϕ [ log ⁡ p θ ( X ∣ z ) ] \mathbb{E}_{q_\phi}\left[\log p_\theta(X\mid z)\right] Eqϕ[logpθ(Xz)]称为Reconstruction Loss,通常可以通过采样的方式估计其值: E q ϕ [ log ⁡ p θ ( X ∣ z ) ] ≈ 1 m ∑ i = 1 m log ⁡ p θ ( X ∣ z i ) , \mathbb{E}_{q_\phi}\left[\log p_\theta(X\mid z)\right]\approx\frac{1}{m}\sum_{i=1}^{m}\log p_\theta\left(X\mid z_i\right), Eqϕ[logpθ(Xz)]m1i=1mlogpθ(Xzi),其中 z i ∼ q ϕ ( z ∣ x i ) = N ( z ∣ μ ( x i ; ϕ ) , σ 2 ( x i ; ϕ ) ∗ I ) 。 z_i\sim q_\phi\left(z\mid x_i\right)=\mathcal{N}\left(z\mid\mu\left(x_i;\phi\right),\sigma^2\left(x_i;\phi\right)*I\right)\text{。} ziqϕ(zxi)=N(zμ(xi;ϕ),σ2(xi;ϕ)I)

若假设数据为固定方差的高斯分布,则极大似然估计后得到的目标函数等价于极小平方估计。设每个数据点 x i x_{i} xi的维度为 K K K,即 X ∣ z i X|z_{i} Xzi服从一个K维高斯分布,易得: log ⁡ p θ ( X ∣ z i ) = log ⁡ exp ⁡ ( − 1 2 ( X − μ ′ ) T Σ ′ − 1 ( X − μ ′ ) ) ( 2 π ) k ∣ Σ ′ ∣ = − 1 2 ( X − μ ′ ) T Σ ′ − 1 ( X − μ ′ ) − log ⁡ ( 2 π ) k ∣ Σ ′ ∣ = − 1 2 ∑ k = 1 K ( X ( k ) − μ ′ ( k ) ) 2 σ ′ ( k ) − log ⁡ ( 2 π ) K ∏ k = 1 K σ ′ ( k ) . \begin{aligned} \log p_{\theta}\left(X\mid z_{i}\right)& =\log\frac{\exp\left(-\frac{1}{2}(X-\mu^{\prime})^{\mathrm{T}}\Sigma^{\prime{-1}}(X-\mu^{\prime})\right)}{\sqrt{(2\pi)^k|\Sigma^{\prime}|}} \\ &=-\frac12(X-\mu^{\prime})^{\mathrm{T}}\Sigma^{\prime{-1}}(X-\mu^{\prime})-\log\sqrt{(2\pi)^k|\Sigma^{\prime}|} \\ &=-\frac12\sum_{k=1}^K\frac{(X^{(k)}-\mu^{\prime{(k)}})^2}{\sigma^{\prime{(k)}}}-\log\sqrt{(2\pi)^K\prod_{k=1}^K\sigma^{\prime{(k)}}}. \end{aligned} logpθ(Xzi)=log(2π)kΣ exp(21(Xμ)TΣ1(Xμ))=21(Xμ)TΣ1(Xμ)log(2π)kΣ =21k=1Kσ(k)(X(k)μ(k))2log(2π)Kk=1Kσ(k) .,至此,对于数据集 X X X的每一个数据 x i x_{i} xi都进行上次计算,然后计算总的损失函数: L = − 1 n ∑ i = 1 n ℓ ( p θ , q ϕ ) = 1 n ∑ i = 1 n D K L ( q ϕ , p ) − 1 n ∑ i = 1 n E q ϕ [ log ⁡ p θ ( x i ∣ z ) ] = 1 n ∑ i = 1 n D K L ( q ϕ , p ) − 1 n m ∑ i = 1 n ∑ j = 1 m log ⁡ p θ ( x i ∣ z j ) . \begin{aligned} \mathcal{L}& =-\frac1n\sum_{i=1}^n\ell(p_\theta,q_\phi) \\ &=\frac1n\sum_{i=1}^nD_{KL}\left(q_\phi,p\right)-\frac1n\sum_{i=1}^n\mathbb{E}_{q_\phi}\left[\log p_\theta(x_i\mid z)\right] \\ &=\frac{1}{n}\sum_{i=1}^{n}D_{KL}\left(q_{\phi},p\right)-\frac{1}{nm}\sum_{i=1}^{n}\sum_{j=1}^{m}\log p_{\theta}\left(x_{i}\mid z_{j}\right). \end{aligned} L=n1i=1n(pθ,qϕ)=n1i=1nDKL(qϕ,p)n1i=1nEqϕ[logpθ(xiz)]=n1i=1nDKL(qϕ,p)nm1i=1nj=1mlogpθ(xizj).,前面提到过的采样以进行近似计算是一件非常computation costly的工作,相比之下这里的采样计算实际上是从分布 q ϕ ( z ∣ x i ) q_{\phi}(z\mid x_{i}) qϕ(zxi)中采样得到的 z j z_{j} zj,在网络的训练过程中,近似分布很快就会逼近分布 log ⁡ p θ ( x i ∣ z ) \log p_{\theta}(x_{i}\mid z) logpθ(xiz),这样一便很大可能能够在有限次数的采样中,采样到与 x i x_{i} xi关联的 z j z_{j} zj

在编写代码时只对一个 x i x_{i} xi只采样一个 z j z_{j} zj,即 m = 1 m=1 m=1就可以达到不错的效果。所以可以以将损失改写为:KaTeX parse error: Double superscript at position 194: …+\sigma_i^{(j)}^̲2+\mu_i^{(j)}^2…由于在前文假设了 p θ ( X ∣ z i ) p_{\theta}(X \mid z_i) pθ(Xzi)对任意 z i z_{i} zi均是方差固定的各维度独立的 K K K维高斯分布,不妨令超参数 σ σ σ为元素值全为 1 2 \frac{1}{2} 21 K K K维向量。则损失函数可以继续改为写:

其中, x i x_{i} xi代表第i个样本,是Encoder的输入。 μ i μ_{i} μi σ i 2 \sigma_i^2 σi2是Encoder的输出,代表 z ∣ x i z∣x_{i} zxi的分布的参数。 z i z_{i} zi是从 z ∣ x i z∣x_{i} zxi中采样得到的一个样本,它是Decoder的输入。 μ i ′ μ_{i}^{'} μi σ i ′ 2 \sigma_{i}^{'2} σi2是Decoder的输出,用来表示 z i z_{i} zi解码后对应的数据点 f ( z i ) = x ^ i f(z_{i})=\hat{x}_{i} f(zi)=x^i

至此,得到了在假设先验和后验分布均是高斯分布的情况下,VAE最的损失函数。而采用高斯分布只是因为其推导时的简便性,我们同样可以根据数据的情况,假设更加复杂分布来推导、训练VAE。如果使用更加复杂的分布,VAE也许会更难训练,但是对于现实情况的表示会更好。

CVAE

训练好了VAE后,仍然存在一个问题,虽然我们能够通过采样一个 z i z_{i} zi来生成一个对应的 x i x_{i} xi,但无法保证 x i x_{i} xi是属于哪一类的图片。以MNIST手写数字数据集为例,我们无法保证生成对应的数字,因为每次生成的结果都是随机的。因为目前的VAE是无监督训练的,因此很自然想到:如果利用有标签的数据来训练VAE,得到的模型能不能根据标签来生成对应的图片呢?答案是肯定的,这种情况叫做Conditional VAE,或者叫CVAE。

CVAE的实现思路相比VAE也很简单,只需要做些修改。假设我们现在的数据集为 X X X
,以及数据集 X X X的标签 Y Y Y,那么我们只需要在训练VAE的时候,每一步都考虑标签 Y Y Y即可。即:
1.原来MLE是对 p θ ( X ) p_{\theta}(X) pθ(X)建模,那么现在需要对建模 p θ ( X ∣ Y ) p_{\theta}(X|Y) pθ(XY)
2.原来是对 p ( z ) p(z) p(z)进行建模,现在是对 p ( z ∣ y i ) p(z|y_{i}) p(zyi)
3.原来的Encoder是对近似后验分布 q ϕ ( z ∣ x i ) q_{\phi}(z|x_{i}) qϕ(zxi)进行建模,现在是对近似后验分布 q ϕ ( z ∣ x i , y i ) q_{\phi}(z|x_{i},y_{i}) qϕ(zxi,yi)
4.原来Decoder是对 p θ ( X ∣ z i ) p_{\theta}(X|z_{i}) pθ(Xzi)估计,现在即是对 p θ ( X ∣ z i , y i ) p_{\theta}(X|z_{i},y_{i}) pθ(Xzi,yi)估计

在前面的讨论中,我们希望数据集 X X X经过编码后, z z z的分布都具有零均值和单位方差,这个“希望”是通过加入了KL loss来实现的。如果现在多了类别信息 Y Y Y
,我们可以希望同一个类的样本都有一个专属的均值 μ Y μ^Y μY(方差不变,还是单位方差),这个 μ Y μ^Y μY让模型自己训练出来。这样的话,有多少个类就有多少个正态分布,而在生成的时候,我们就可以通过控制均值来控制生成图像的类别。事实上,这样可能也是在VAE的基础上加入最少的代码来实现CVAE的方案。

测试代码:

苏剑林老师的代码地址如下:VAE

下面是一个基于手写数据集的简单实现:

import os
from tqdm import tqdm
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
from matplotlib import pyplot as plt


class VAE(nn.Module):

    def __init__(self, in_features, latent_size, y_size=0):
        super(VAE, self).__init__()

        self.latent_size = latent_size
        
		# 编码器网络层
        self.encoder_forward = nn.Sequential(
            nn.Linear(in_features + y_size, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, self.latent_size * 2)
        )
        
		# 解码器网络层
        self.decoder_forward = nn.Sequential(
            nn.Linear(self.latent_size + y_size, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, in_features),
            nn.Sigmoid()
        )
	# 编码器,返回mu和sigma^2
    def encoder(self, X):
        out = self.encoder_forward(X)
        mu = out[:, :self.latent_size]
        log_var = out[:, self.latent_size:]
        return mu, log_var
        
	# 解码器,返回mu'来作为采样结果
    def decoder(self, z):
        mu_prime = self.decoder_forward(z)
        return mu_prime
        
	# 参数重整化
    def reparameterization(self, mu, log_var):
        epsilon = torch.randn_like(log_var)
        z = mu + epsilon * torch.sqrt(log_var.exp())
        return z

	# 损失函数
    def loss(self, X, mu_prime, mu, log_var):
        # reconstruction_loss = F.mse_loss(mu_prime, X, reduction='mean') is wrong!
        reconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))

        latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - log_var).sum(dim=1))
        return reconstruction_loss + latent_loss

	# 前向传播
    def forward(self, X, *args, **kwargs):
        mu, log_var = self.encoder(X)
        z = self.reparameterization(mu, log_var)
        mu_prime = self.decoder(z)
        return mu_prime, mu, log_var


class CVAE(VAE):

    def __init__(self, in_features, latent_size, y_size):
        super(CVAE, self).__init__(in_features, latent_size, y_size)

    def forward(self, X, y=None, *args, **kwargs):
        y = y.to(next(self.parameters()).device)
        X_given_Y = torch.cat((X, y.unsqueeze(1)), dim=1)

        mu, log_var = self.encoder(X_given_Y)
        z = self.reparameterization(mu, log_var)
        z_given_Y = torch.cat((z, y.unsqueeze(1)), dim=1)

        mu_prime_given_Y = self.decoder(z_given_Y)
        return mu_prime_given_Y, mu, log_var


def train(model, optimizer, data_loader, device, name='VAE'):
    model.train()

	# 总损失,一个epoch的损失
    total_loss = 0
    pbar = tqdm(data_loader) # 时间条
    for X, y in pbar:
        batch_size = X.shape[0]
        X = X.view(batch_size, -1).to(device) # 拉成向量,维度(batch_size,自动计算)
        model.zero_grad() # 梯度清零

        if name == 'VAE':
            mu_prime, mu, log_var = model(X)
        else:
            mu_prime, mu, log_var = model(X, y)

		# 损失函数
        loss = model.loss(X.view(batch_size, -1), mu_prime, mu, log_var)
        loss.backward() # 反向传播
        optimizer.step() # 参数更新

        total_loss += loss.item()
        pbar.set_description('Loss: {loss:.4f}'.format(loss=loss.item()))

    return total_loss / len(data_loader) # 返回平均损失


@torch.no_grad()
def save_res(vae, cvae, data, latent_size, device):
    num_classes = len(data.classes)

    # raw samples from dataset
    out = []
    for i in range(num_classes):
        img = data.data[torch.where(data.targets == i)[0][:num_classes]] # 对于每个num_class选取num_claass个图片
        out.append(img)
    out = torch.stack(out).transpose(0, 1).reshape(-1, 1, 28, 28) # out的维度(num_classes ** 2,1,28,28)
    save_image(out.float(), './img/raw_samples.png', nrow=num_classes, normalize=True)

    # samples generated by vanilla VAE
    z = torch.randn(num_classes ** 2, latent_size).to(device)
    out = vae.decoder(z)
    save_image(out.view(-1, 1, 28, 28), './img/vae_samples.png', nrow=num_classes)

    # sample generated by CVAE
    z = torch.randn(num_classes ** 2, latent_size).to(device)
    y = torch.arange(num_classes).repeat(num_classes).to(device)
    z_given_Y = torch.cat((z, y.unsqueeze(1)), dim=1)
    out = cvae.decoder(z_given_Y)
    save_image(out.view(-1, 1, 28, 28), './img/cvae_samples.png', nrow=num_classes)
    
def plot_data(data):
    # 带坐标轴的显示部分图片
    
    # data = MNIST('../../data/', download=True, transform=transforms.ToTensor())
    num_classes = len(data.classes)
    out = []
    for i in range(num_classes):
        img = data.data[torch.where(data.targets == i)[0][:num_classes]]
        out.append(img)
    d = torch.stack(out).transpose(0, 1)
    out = torch.stack(out).transpose(0, 1).reshape(-1, 1, 28, 28) # out的维度(num_class **2, 1, 28, 28)
    tensor = torchvision.transforms.ToPILImage()(out[0]) # 三维Tensor(channel,height,width)
    tensor = out[0].permute(1,2,0) # (channel,height,width)→(height,width,channel),这是因为plt.imshow的格式为(h,w,c)
    plt.imshow(tensor)
    plt.show()

    for i in range(num_classes **2): # i从0到99
        plt.subplot(10,10,i+1)
        tensor = torchvision.transforms.ToPILImage()(out[i]) # 三维Tensor(channel,height,width)
        tensor = out[i].permute(1,2,0) # (channel,height,width)→(height,width,channel),这是因为plt.imshow的格式为(h,w,c)
        plt.imshow(tensor)
    plt.show()
plot_data(data)

def plot_data1(data):
    # 不带坐标轴的显示部分图片
    
    # data = MNIST('../../data/', download=True, transform=transforms.ToTensor())
    num_classes = len(data.classes)
    out = []
    for i in range(num_classes):
        img = data.data[torch.where(data.targets == i)[0][:num_classes]]
        out.append(img)
    d = torch.stack(out).transpose(0, 1)
    out = torch.stack(out).transpose(0, 1).reshape(-1, 1, 28, 28) # out的维度(num_class **2, 1, 28, 28)
    tensor = torchvision.transforms.ToPILImage()(out[0]) # 三维Tensor(channel,height,width)
    tensor = out[0].permute(1,2,0) # (channel,height,width)→(height,width,channel),这是因为plt.imshow的格式为(h,w,c)
    current_axes=plt.axes()
    current_axes.xaxis.set_visible(False)
    current_axes.yaxis.set_visible(False)
    plt.imshow(tensor)
    plt.show()
    
    
    fig, axs = plt.subplots(num_classes, num_classes) # fig.subplots()”是生成一个画布一个子图时用的,如果想要在同一个画布中生成多个子图,需要用“subplot()”函数
#     print(axs.shape) # 查看axs维度
    for i in range(num_classes **2): # i从0到99
        j = i // 10 # 商,代表行
        k = i % 10 # 余数,代表列
        tensor = torchvision.transforms.ToPILImage()(out[i]) # 三维Tensor(channel,height,width)
        tensor = out[i].permute(1,2,0) # (channel,height,width)→(height,width,channel),这是因为plt.imshow的格式为(h,w,c)
        off = "off" 
        # exec() 函数可以理解为执行一段写在字符串中的代码语句。
        exec(f'axs[{j}][{k}].axis(off)') # 去除坐标轴
        exec(f'axs[{j}][{k}].xaxis.set_ticks([])')  # 去除X轴
        exec(f'axs[{j}][{k}].yaxis.set_ticks([])')  # 去除Y轴
        exec(f'axs[{j}][{k}].imshow(tensor, interpolation=None)')
        plt.imshow(tensor)
        
    plt.show()
    
plot_data1(data)

def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device) # 有GPU就是用GPU

    batch_size = 256 * 4
    epochs = 50
    latent_size = 64
    in_features = 28 * 28
    lr = 0.001

	#构建数据集
    data = MNIST('../../dataset/', download=True, transform=transforms.ToTensor())
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

    # train VAE
    # 模型
    vae = VAE(in_features, latent_size).to(device)
    # 优化器
    optimizer = torch.optim.AdamW(vae.parameters(), lr=lr)

    print('Start Training VAE...')
    # 开始迭代,一共迭代epoch次
    for epoch in range(1, 1 + epochs):
        loss = train(vae, optimizer, data_loader, device, name='VAE')
        print("Epochs: {epoch}, AvgLoss: {loss:.4f}".format(epoch=epoch, loss=loss))
    print('Training for VAE has been done.')

    # train VCAE
    cvae = CVAE(in_features, latent_size, y_size=1).to(device)
    optimizer = torch.optim.AdamW(cvae.parameters(), lr=lr)

    print('Start Training CVAE...')
    for epoch in range(1, 1 + epochs):
        loss = train(cvae, optimizer, data_loader, device, name='CVAE')
        print("Epochs: {epoch}, AvgLoss: {loss:.4f}".format(epoch=epoch, loss=loss))
    print('Training for CVAE has been done.')

    save_res(vae, cvae, data, latent_size, device)


if __name__ == '__main__':
    main()

参考博客

机器学习方法—优雅的模型(一):变分自编码器(VAE)
变分自编码器(一):原来是这么一回事

  • 25
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值