变分自编码器-推断输入分布的有向概率模型

变分自编码器-推断输入分布的有向概率模型

0.摘要

本文主要分析一种用于生成新样本数据的有向概率模型——变分自编码器。总结相关的数学知识和方法,并对模型进行详细的分析与推导。用PyTorch深度学习框架搭建相关模型,在MNIST数据集上进行实验并分析结果。最后给出模型改进和拓展。
关键词: 有向概率模型,后验分布推断,蒙特卡罗采样,变分贝叶斯估计,深度学习,变分自编码器(VAE) ,GAN,图像生成

1.简介


在这里插入图片描述


变分自编码器(VAE)是一种有向概率模型,本质是生成模型。它假设我们得到的样本都是服从某个复杂分布 P ( X ) P(X) P(X), 即 x ∼ P ( X ) x\sim P(X) xP(X), 生成模型的目的就是要建模输入数据的分布 P ( X ) P(X) P(X),这样我们就可以从分布中进行采样,得到新的样本数据。
由于数据分布函数估计不准确,后验分布难以处理,由采样无法直接得到。为了从模型生成样本,VAE从编码分布(通常假设为高斯分布) p ( z ) p(z) p(z) 中采样隐变量 z z z,然后使样本通过可微生成器网络 f ( z ; θ ) f(z;\theta) f(z;θ),将采样的隐变量 z z z映射为与X比较相似的样本数据,即使 p ( X ∣ z ) p(X|z) p(Xz)的概率更高。最后从分布 p θ ( x ; f ( z ; θ ) ) = p θ ( x ∣ z ) p_{\theta}(x;f(z;\theta))=p_{\theta}(x|z) pθ(x;f(z;θ))=pθ(xz)中采样 x x x。同时搭建一个近似推断网络 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx),在模型的训练期间用于获得 z z z。与自编码器进行对比发现,生成器网络可视为解码器(Decoder), 而近似推断网络则可视为编码器(Encoder)。

2.数学基础及方法

2.1 贝叶斯法则(Bayesian Law)

当我们知道事件 A A A的先验概率为 P ( A ) P(A) P(A),事件B的先验概率为 P ( B ) P(B) P(B)以及事件 A A A发生后事件 B B B发生的条件概率( B B B的后验概率) P ( B ∣ A ) P(B|A) P(BA), 我们可以根据贝叶斯法则得到事件 B B B发生后事件 A A A发生的条件概率( A A A的后验概率) P ( A ∣ B ) P(A|B) P(AB)。即公式:

P ( A ∣ B ) = P ( A ) P ( B ∣ A ) P ( B ) P(A|B)=\frac{P(A)P(B|A)}{P(B)} P(AB)=P(B)P(A)P(BA)
2.2 蒙特卡罗方法(Monte Carlo Method)

当无法精确计算和或积分时,通常可以使用蒙特卡罗采样来近似它。这种想法把和或者积分视作某分布下的期望,然后通过估计对应的平均值来近似这个期望。即令

s = ∑ x p ( x ) f ( x ) = E p [ f ( x ) ] s=\sum_{x}p(x)f(x)=E_{p}[f(x)] s=xp(x)f(x)=Ep[f(x)]

或者

s = ∫ p ( x ) f ( x ) = E p [ f ( x ) ] s=\int p(x)f(x)=E_{p}[f(x)] s=p(x)f(x)=Ep[f(x)]

p p p是一个关于随机变量 x x x的概率分布或者概率密度函数
我们可以通过从 p p p中抽取n个样本 x ( 1 ) , ⋅ ⋅ ⋅ , x ( n ) x^{(1)},···,x^{(n)} x(1),,x(n)来近似s并得到一个经验平均值

s n ^ = 1 n ∑ i = 1 n f ( x ( i ) ) \hat{s_n}=\frac{1}{n}\sum^{n}_{i=1}f(x^{(i)}) sn^=n1i=1nf(x(i))

并且有

E [ s n ^ ] = 1 n ∑ i = 1 n E [ f ( x ( i ) ) ] = 1 n ∑ i = 1 n s = s E[\hat{s_n}]=\frac{1}{n}\sum^{n}_{i=1}E[f(x^{(i)})]=\frac{1}{n}\sum^n_{i=1}s=s E[sn^]=n1i=1nE[f(x(i))]=n1i=1ns=s

容易观察到 s n ^ \hat{s_n} sn^这个估计是无偏的。此外,根据大数定理,如果样本 x ( i ) x^{(i)} x(i)是独立同分布的,那么该均值几乎必然会收敛到真实的期望值,即:

lim ⁡ n → + ∞ s n ^ = s {\lim_{n \to +\infty}}\hat{s_n}=s n+limsn^=s
2.3 K L KL KL散度(KL-Divergence)

如果对于同一个随机变量 x x x有两个单独的概率分布 P ( x ) P(x) P(x) Q ( x ) Q(x) Q(x),可以用 K L KL KL散度(相对熵)来衡量这两个分布的差异:

D K L ( P ∣ ∣ Q ) = E x ~ P [ l o g P ( x ) Q ( x ) ] = E x ~ P [ l o g P ( x ) − l o g Q ( x ) ] D_{KL}(P||Q)=E_{x~P}[log\frac{P(x)}{Q(x)}]=E_{x~P}[logP(x)-logQ(x)] DKL(PQ)=ExP[logQ(x)P(x)]=ExP[logP(x)logQ(x)]
2.4 最大似然值估计(Maximum Likelihood Estimation)

最大似然估计是对不同模型中得到特定函数作为好的估计的准则。考虑一组含有m个样本的数据集 X = { x ( 1 ) , ⋅ ⋅ ⋅ , x ( m ) } X={\{x^{(1)},···,x^{(m)}\}} X={x(1),,x(m)}独立地由未知的真实数据生成分布 p d a t a ( x ) p_{data}(x) pdata(x)生成。令 p m o d e l ( x ; θ ) p_{model}(x;\theta) pmodel(x;θ)为将输入x映射到实数来估计真实概率 p d a t a ( x ) p_{data}(x) pdata(x)的相同空间上的概率分布,对 θ \theta θ的最大似然估计定义为:

θ M L = arg ⁡ max ⁡ θ p m o d e l ( X ; θ ) = arg ⁡ max ⁡ θ ∏ i = 1 m p m o d e l ( x ( i ) ; θ ) \theta_{ML}= \mathop{\arg\max}\limits_{\theta}p_{model}(X;\theta)=\mathop{\arg\max}\limits_{\theta}\prod^{m}_{i=1}p_{model}(x^{(i)};\theta) θML=θargmaxpmodel(X;θ)=θargmaxi=1mpmodel(x(i);θ)

因为多个乘积难以计算,可能会出现数值下溢,所以我们重写为对数似然的形式。又因为重新放缩代价函数时结果不会改变,我们除以m得到和训练数据经验分布 p ^ d a t a \hat{p}_{data} p^data相关的期望作为准则:

θ M L = arg ⁡ max ⁡ θ   E x ~ p d a t a ^   l o g   p m o d e l ( x ; θ ) \theta_{ML}= \mathop{\arg\max}\limits_{\theta}\ E_{x~\hat{p_{data}}}\ log \ p_{model}(x;\theta) θML=θargmax Expdata^ log pmodel(x;θ)

将最大似然估计扩展到估计条件概率 P ( y   ∣   x ; θ ) P(y\ |\ x;\theta) P(y  x;θ),从而给定 x x x预测 y y y

θ M L = arg ⁡ max ⁡ θ P ( Y   ∣   X ; θ ) \theta_{ML}= \mathop{\arg\max}\limits_{\theta}P(Y\ |\ X;\theta) θML=θargmaxP(Y  X;θ)

假设样本是独立同分布的,改为对数似然形式,那么上式可以分解为:

θ M L = arg ⁡ max ⁡ θ ∑ i = 1 m   l o g   P ( y ( i )   ∣   x ( i ) ; θ ) \theta_{ML}= \mathop{\arg\max}\limits_{\theta}\sum^m_{i=1}\ log\ P(y^{(i)}\ |\ x^{(i)};\theta) θML=θargmaxi=1m log P(y(i)  x(i);θ)
2.5 变分贝叶斯推断(Variational Bayes Inference)

简单来说,变分贝叶斯推断目的是为了近似数据的后验分布,它的一般流程为:( i i i)确定好研究模型各个参数的的共轭先验分布,( i i ii ii)写出研究模型的联合分布 P ( Z , X ) P(Z,X) P(Z,X), ( i i i iii iii)根据联合分布确定变分分布的形式 Q ( Z ) Q(Z) Q(Z), ( i v iv iv)对于每个变分因子 Q ( z j ) Q(z_j) Q(zj)求出 P ( Z , X ) P(Z,X) P(Z,X)关于不包含变量 z j z_j zj的数学期望,再规整化为概率分布。具体数学公式推导及应用会在第三部分中进行分析。

3.变分自编码器推导

3.1 主体思想推导

现在我们来进行变分自编码器的理论推导。首先,我们假设有一组输入样本数据为 X = { X 1 , X 2 , ⋅ ⋅ ⋅ X N } X=\{X_1,X_2,···X_N\} X={X1,X2,XN},我们的目的是要求得 X X X的概率分布 P ( X ) P(X) P(X)。然而数据量不够大,直接估计不准确。于是我们认为图片的特征有一些列连续的隐变量控制。因此,由全概率公式可将X的概率分布写为:

P ( X ) = ∫ p ( x ∣ z ) p ( z )   d z P(X)=\int p(x|z)p(z)\ dz P(X)=p(xz)p(z) dz

引入一个隐变量 Z Z Z之后,比如 P ( Z ) ∼ N ( 0 , I P(Z)\sim \mathcal N(0,I P(Z)N(0,I),然后得到 X X X相对于随机变量 Z Z Z的条件分布。如果可以实现,则是从标准高斯分布中进行采样一个 z z z,然后基于 z z z去计算产生一个 X X X,从而可以得到 X X X的分布。然而,现在仅仅知道 p ( z ) p(z) p(z),而后验分布 p ( x ∣ z ) p(x|z) p(xz)未知,我们并不能得到 X X X的概率分布。于是,我们使用神经网络把 p ( z ) p(z) p(z)中采样得到的数据 z z z映射为与 X X X比较相似的样本数据。即训练映射 f ( z ; θ ) f(z;\theta) f(z;θ)(相当于解码器) ,则有 p θ ( x ∣ z ) = N ( f ( z ; θ ) , σ 2 I ) p_{\theta}(x|z)=\mathcal N(f(z;\theta),\sigma^2I) pθ(xz)=N(f(z;θ),σ2I),其中 σ \sigma σ 是一个超参数。我们可以证明,高斯混合模型是概率密度的万能近似器,在这种意义下,任何平滑的概率密度都可以用具有足够多组件的高斯混合模型以任意精度来逼近。也就是说任意一个 d d d维的复杂分布都可以通过对 d d d维高斯分布使用一个复杂的变换得到。因此,给定一个表达能力足够强的函数,可以将服从高斯分布的隐变量映射为模型需要的隐变量,再映射为 x x x
现在我们该如何去计算或者说训练这个神经网络 f ( z ; θ ) f(z;\theta) f(z;θ)?由贝叶斯法则我们知道,真实的后验分布为:

p θ ( z ∣ x ) = p θ ( x ∣ z ) p ( z ) p θ ( x ) p_{\theta}(z|x)=\frac{p_{\theta}(x|z)p(z)}{p_{\theta}(x)} pθ(zx)=pθ(x)pθ(xz)p(z)

所以,现在的关键在于后验分布。使用相同的策略,我们训练一个神经网络映射为 g ( x ; ϕ ) g(x;\phi) g(x;ϕ)(相当于编码器)生成一个后验分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)来逼近真实的后验分布 p θ ( z ∣ x ) p_{\theta}(z|x) pθ(zx)。这时候,我们就使用 K L KL KL散度来度量两个分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx) p θ ( z ∣ x ) p_{\theta}(z|x) pθ(zx),即:

D K L [ q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ) ] = ∫ q ϕ ( z ∣ x ) q ϕ ( z ∣ x ) log ⁡ q ϕ ( z ∣ x ) p θ ( z ∣ x ) D_{KL}[q_{\phi}(z|x)||p_{\theta}(z|x)] = \int_{q_{\phi}(z|x)}q_{\phi}(z|x)\log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)}\\ DKL[qϕ(zx)pθ(zx)]=qϕ(zx)qϕ(zx)logpθ(zx)qϕ(zx)
= E q ϕ ( z ∣ x ) [ l o g   q ϕ ( z ∣ x ) − log ⁡ p θ ( z ∣ x ) ]                       =E_{q_{\phi}(z|x)}[log\ q_{\phi}(z|x)-\log p_{\theta}(z|x)]\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =Eqϕ(zx)[log qϕ(zx)logpθ(zx)]                     
= E q ϕ ( z ∣ x ) [ l o g   q ϕ ( z ∣ x ) − log ⁡ p θ ( x , z ) + log ⁡ p θ ( x ) ]   =E_{q_{\phi}(z|x)}[log\ q_{\phi}(z|x)-\log p_{\theta}(x,z)+\log p_{\theta}(x)] \ =Eqϕ(zx)[log qϕ(zx)logpθ(x,z)+logpθ(x)] 
= E q ϕ ( z ∣ x ) [ l o g   q ϕ ( z ∣ x ) − log ⁡ p θ ( x , z ) ] + log ⁡ p θ ( x )   (1) =E_{q_{\phi}(z|x)}[log\ q_{\phi}(z|x)-\log p_{\theta}(x,z)]+\log p_{\theta}(x) \ \tag{1} =Eqϕ(zx)[log qϕ(zx)logpθ(x,z)]+logpθ(x) (1)

对(1)式进行整理得到最大似然值估计:

log ⁡ p θ ( x ) = E q ϕ ( z ∣ x ) [ − l o g   q ϕ ( z ∣ x ) + log ⁡ p θ ( x , z ) ] + D K L [ q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ) ] \log p_{\theta}(x)=E_{q_{\phi}(z|x)}[-log\ q_{\phi}(z|x)+\log p_{\theta}(x,z)]+D_{KL}[q_{\phi}(z|x)||p_{\theta}(z|x)] logpθ(x)=Eqϕ(zx)[log qϕ(zx)+logpθ(x,z)]+DKL[qϕ(zx)pθ(zx)]

由于 D K L [ q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ) ] ≥ 0 D_{KL}[q_{\phi}(z|x)||p_{\theta}(z|x)]\geq0 DKL[qϕ(zx)pθ(zx)]0 恒成立,所以:

log ⁡ p θ ( x ) ≥ E q ϕ ( z ∣ x ) [ − l o g   q ϕ ( z ∣ x ) + log ⁡ p θ ( x , z ) ] ) \log p_{\theta}(x) \geq E_{q_{\phi}(z|x)}[-log\ q_{\phi}(z|x)+\log p_{\theta}(x,z)]) logpθ(x)Eqϕ(zx)[log qϕ(zx)+logpθ(x,z)])

设:

L ( θ , ϕ ; x , z ) = E q ϕ ( z ∣ x ) [ − l o g   q ϕ ( z ∣ x ) + log ⁡ p θ ( x , z ) ] ) (2) \mathcal L(\theta,\phi;x,z)=E_{q_{\phi}(z|x)}[-log\ q_{\phi}(z|x)+\log p_{\theta}(x,z)]) \tag{2} L(θ,ϕ;x,z)=Eqϕ(zx)[log qϕ(zx)+logpθ(x,z)])(2)

为变分下界,从而求边际似然的最大值专为求变分下界的最大值。
根据最大似然值法,有:

log ⁡ p ( X ) = ∑ i = 1 N log ⁡ p ( x ( i ) ) \log p(X)=\sum^N_{i=1}\log p(x^{(i)}) logp(X)=i=1Nlogp(x(i))

其中

log ⁡ p ( x ( i ) ) = L ( θ , ϕ ; x ( i ) , z ( i ) ) + D K L [ q ϕ ( z ( i ) ∣ x ( i ) ) ∣ ∣ p θ ( z ( i ) ∣ x ( i ) ) ] \log p(x^{(i)})=\mathcal L(\theta,\phi;x^{(i)},z^{(i)})+D_{KL}[q_{\phi}(z^{(i)}|x^{(i)})||p_{\theta}(z^{(i)}|x^{(i)})] logp(x(i))=L(θ,ϕ;x(i),z(i))+DKL[qϕ(z(i)x(i))pθ(z(i)x(i))]

对变分下界做变换:

L ( θ , ϕ ; x , z ) = E q ϕ ( z ∣ x ) [ − log ⁡ q ϕ ( z ∣ x ) + log ⁡ p ( x , z ) ]                              \mathcal{L}(\theta, \phi;x,z) = E_{q_{\phi}(z|x)}[-\log q_{\phi}(z|x)+\log p(x,z)] \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ L(θ,ϕ;x,z)=Eqϕ(zx)[logqϕ(zx)+logp(x,z)]                            
= E q ϕ ( z ∣ x ) [ − log ⁡ q ϕ ( z ∣ x ) + log ⁡ ( p θ ( x ∣ z ) p ( z ) ) ] =E_{q_{\phi}(z|x)}[-\log q_{\phi}(z|x)+\log (p_{\theta}(x|z)p(z))] =Eqϕ(zx)[logqϕ(zx)+log(pθ(xz)p(z))]
                 = − E q ϕ ( z ∣ x ) [ log ⁡ q ϕ ( z ∣ x ) + p ( z ) ] + E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x ∣ z ) ] \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ = -E_{q_{\phi}(z|x)}[\log q_{\phi}(z|x)+p(z)] + E_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)]                 =Eqϕ(zx)[logqϕ(zx)+p(z)]+Eqϕ(zx)[logpθ(xz)]
    = − D K L [ q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ] + E q ϕ ( z ∣ x ) [ l o g p θ ( x ∣ z ) ] (3) \ \ \ = -D_{KL}[q_{\phi}(z|x)||p(z)] + E_{q_{\phi}(z|x)}[log p_{\theta}(x|z)] \tag{3}    =DKL[qϕ(zx)p(z)]+Eqϕ(zx)[logpθ(xz)](3)
3.2 重参数化技巧

现在存在一个问题,由于 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)是后验分布,而隐变量 z z z是通过蒙特卡罗采样得到的,这样估计真实分布的方法是不可导的,也就不能进行反向传播。于是我们使用重参数化技巧来解决这个问题。即引入一个新的分布 z = g ϕ ( ϵ , x ) , ϵ z = g_{\phi}(\epsilon, x), \epsilon z=gϕ(ϵ,x),ϵ为辅助变量。具体来说,先从标准高斯分布 N ( 0 , I ) \mathcal N(0,I) N(0,I)中通过蒙特卡罗采样得到 ϵ \epsilon ϵ,隐变量 z z z 通过计算得到: z = ϵ μ ( x ) + σ ( x ) z=\epsilon\mu(x)+\sigma(x) z=ϵμ(x)+σ(x) μ ( x ) , σ ( x ) \mu(x),\sigma(x) μ(x),σ(x)为神经网络映射 g ( ϕ , x ) g(\phi,x) g(ϕ,x)所得分布的均值和方差。通过重参数化技巧,我们就可以使用反向传播算法进行迭代训练。

3.3 随机梯度变分贝叶斯估计

对于式(2)的变分下界,我们采用重参数化技巧及Monte Carlo估计得到

E q ϕ ( z ∣ x i ) [ f ( z ; θ ) ] = E p ( ϵ ) [ f ( g ϕ ( ϵ , x ( i ) ) ; θ ) ] ≃ 1 L ∑ l = 1 L f ( g ϕ ( ϵ ( l ) , x ( i ) ) ; θ ) E_{q_{\phi}(z|x^{i})}[f(z;\theta)]=E_{p(\epsilon)}[f(g_\phi(\epsilon, x^{(i)});\theta)]\simeq\frac{1}{L}\sum^L_{l=1}f(g_\phi(\epsilon^{(l)},x^{(i)});\theta) Eqϕ(zxi)[f(z;θ)]=Ep(ϵ)[f(gϕ(ϵ,x(i));θ)]L1l=1Lf(gϕ(ϵ(l),x(i));θ)

其中 ϵ ( l ) ∼ p ( ϵ ) \epsilon^{(l)} \sim p(\epsilon) ϵ(l)p(ϵ)
因此,我们可以根据式(2)得到第一个形式的随机梯度变分贝叶斯估计: L ~ A ( θ , ϕ , x ( i ) , z ( i ) ) \tilde{\mathcal{L}}^A(\theta, \phi, x^{(i)},z^{(i)}) L~A(θ,ϕ,x(i),z(i))

L ~ A ( θ , ϕ , x ( i ) , z ( i ) ) = 1 L ∑ l = 1 L log ⁡ p θ ( x ( i ) , z ( i , l ) ) − log ⁡ q ϕ ( z ( i , l ) ∣ x ( i ) ) \tilde{ \mathcal{L}}^A(\theta, \phi, x^{(i)},z^{(i)})=\frac{1}{L}\sum^L_{l=1}\log p_{\theta}(x^{(i)},z^{(i,l)})-\log q_{\phi}(z^{(i,l)}| x^{(i)}) L~A(θ,ϕ,x(i),z(i))=L1l=1Llogpθ(x(i),z(i,l))logqϕ(z(i,l)x(i))

其中 z ( i , l ) = g ϕ ( ϵ ( i . l ) , x ( i ) ) ,    ϵ ( l ) ∼ p ( ϵ ) z^{(i,l)}=g_\phi(\epsilon^{(i.l)},x^{(i)}),~~\epsilon^{(l)}\sim p(\epsilon) z(i,l)=gϕ(ϵ(i.l),x(i)),  ϵ(l)p(ϵ)
实际上,变分下界(3)中的第一项 D K L ( q ϕ ( z ∣ x ( i ) ) ∣ ∣ p ( z ) ) D_{KL}(q_{\phi}(z|x^{(i)})||p(z)) DKL(qϕ(zx(i))p(z)) 往往是可以直接求出来的,因此可以只对第二项 E q ϕ ( z ∣ x ) [ l o g p θ ( x ∣ z ) ] E_{q_{\phi}(z|x)}[log p_{\theta}(x|z)] Eqϕ(zx)[logpθ(xz)] 进行Monte Carlo估计,从而得到第二个形式的SGVB估计:

L ~ B ( θ , ϕ , x ( i ) , z ( i ) ) = − D K L ( q ϕ ( z ∣ x ( i ) ) ∣ ∣ p ( z ) ) + 1 L ∑ i = 1 L ( log ⁡ p θ ( x ( i ) ∣ z ( i , l ) ) ) \tilde{ \mathcal{L}}^B(\theta, \phi, x^{(i)},z^{(i)})=-D_{KL}(q_{\phi}(z|x^{(i)})||p(z))+\frac{1}{L}\sum^L_{i=1}(\log p_{\theta}(x^{(i)}|z^{(i,l)})) L~B(θ,ϕ,x(i),z(i))=DKL(qϕ(zx(i))p(z))+L1i=1L(logpθ(x(i)z(i,l)))

而这个由随机梯度变分贝叶斯估计得到的变分下界就是我们最终的损失函数。我们的优化目标函数由最大似然估计进行求解转化为对变分下界求最大值,此时保证了我们生成的后验分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)能够尽可能地接近真实分布 p θ ( z ∣ x ) p_{\theta}(z|x) pθ(zx)

3.4 具体形式推导

现在,我们推导变分自编码器的具体形式。首先,假设隐变量 z z z服从标准高斯分布,即 p ( z ) = N ( 0 , I ) p(z)=\mathcal N(0,I) p(z)=N(0,I),继而假设由两个神经网络生成的后验分布 p θ ( x ∣ z ) p_{\theta}(x|z) pθ(xz) q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)为多元高斯分布,即 p θ ( x ∣ z ) = N ( ( f ( z ; θ ) , σ 2 I ) ,    q ϕ ( z ∣ x ( i ) ) = N ( z ; μ i ( x ( i ) ) , σ i ( x ( i ) ) ) p_{\theta}(x|z)=\mathcal N((f(z;\theta),\sigma^2I),\ \ q_{\phi}(z|x^{(i)})=\mathcal N(z;\mu_i(x^{(i)}),\sigma_i(x^{(i)})) pθ(xz)=N((f(z;θ),σ2I),  qϕ(zx(i))=N(z;μi(x(i)),σi(x(i)))。设 z z z的维度为 J J J x x x的维度为 D D D
计算损失函数的第一项:

− D K L [ q ( z ∣ x ) ∣ ∣ p ( z ) ] = − ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) p ( z )                                                     -D_{KL}[q(z|x)||p(z)] = -\int q(z|x)\log \frac{q(z|x)}{p(z)}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ DKL[q(zx)p(z)]=q(zx)logp(z)q(zx)                                                   
= − ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) d z + ∫ q ( z ∣ x ) log ⁡ p ( z ) d z                                  = -\int q(z|x)\log q(z|x)dz +\int q(z|x)\log {p(z)}dz\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =q(zx)logq(zx)dz+q(zx)logp(z)dz                                
= ∫ N ( z ; μ , σ 2 ) log ⁡ N ( z ; 0 , I ) d z − ∫ N ( z ; μ , σ 2 ) log ⁡ N ( z ; μ , σ 2 ) d z = \int \mathcal{N}(z;\mu,\sigma^{2})\log \mathcal{N}(z;0,I)dz - \int \mathcal{N}(z;\mu,\sigma^{2})\log \mathcal{N}(z;\mu,\sigma^2)dz =N(z;μ,σ2)logN(z;0,I)dzN(z;μ,σ2)logN(z;μ,σ2)dz
     = − J 2 log ⁡ ( 2 π ) − 1 2 ∑ j = 1 J ( μ j 2 + σ j 2 ) − ( − J 2 log ⁡ ( 2 π ) − 1 2 ∑ j = 1 J ( 1 + log ⁡ σ j 2 ) ) \ \ \ \ = -\frac{J}{2}\log (2\pi) -\frac{1}{2}\sum^J_{j=1}(\mu_j^2+\sigma_j^2)-(-\frac {J}{2}\log (2\pi) - \frac{1}{2} \sum^J_{j=1}(1+\log \sigma_j^2))     =2Jlog(2π)21j=1J(μj2+σj2)(2Jlog(2π)21j=1J(1+logσj2))
= 1 2 ∑ j = 1 J ( 1 + l o g ( ( σ j 2 ) ) − ( μ j 2 ) − ( σ j 2 ) )                                                   = \frac{1}{2}\sum^J_{j=1}(1 + log ((\sigma_j^2))- (\mu_j^2)- (\sigma_j^2))\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =21j=1J(1+log((σj2))(μj2)(σj2))                                                 

计算损失函数第二项:

− log ⁡ p θ ( x ∣ z ) = 1 Π d = 1 D 2 π σ 2 e ( − 1 2 ∣ ∣ x − f ( z ; θ ) σ ∣ ∣ 2 )                                  -\log p_{\theta}(x|z) = \frac{1}{\Pi^D_{d=1}\sqrt{2\pi \sigma^2}} e^{(-\frac{1}{2}||\frac{x-f(z;\theta)}{\sigma}||^2)}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ logpθ(xz)=Πd=1D2πσ2 1e(21σxf(z;θ)2)                                
                = 1 2 ∣ ∣ x − f ( z ; θ ) σ ∣ ∣ 2 + D 2 log ⁡ 2 π + D 2 l o g σ 2 \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =\frac{1}{2} ||\frac{x-f(z;\theta)}{\sigma}||^2 + \frac{D}{2}\log 2\pi + \frac{D}{2}log \sigma^2                =21σxf(z;θ)2+2Dlog2π+2Dlogσ2
∼ 1 2 ∣ ∣ x − μ ( z ) σ ∣ ∣ 2                              \sim \frac{1}{2}||\frac{x-\mu(z)}{\sigma}||^2\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ 21σxμ(z)2                            

至此,我们已经得到最终损失函数,完成对变分自编码器的理论推导。


在这里插入图片描述

4.AEVB算法

经过上述推导,我们可以得到变分贝叶斯自编码器(AEVB)模型算法总结为:


Algorithm : Minibatch version of the Auto-Encoding VB (AEVB) algorithm.
θ \theta θ, ϕ \phi ϕ ← Initialize parameters
repeat
             X M \ \ \ \ \ \ \ \ \ \ \ \ X^M             XM ← \leftarrow Random minibatch of M datapoints (drawn from full dataset)
             ϵ \ \ \ \ \ \ \ \ \ \ \ \ \epsilon             ϵ ← \leftarrow Random samples from noise distribution p ( ϵ ) p(\epsilon) p(ϵ)
             g \ \ \ \ \ \ \ \ \ \ \ \ g             g ← ∇ θ , ϕ L ~ M ( θ , ϕ ; X M , ϵ ) \leftarrow \nabla_{\theta,\phi}\widetilde{L}^{M}(\theta,\phi;X^M,\epsilon) θ,ϕL M(θ,ϕ;XM,ϵ) (Gradients of minibatch estimator)
             θ , ϕ \ \ \ \ \ \ \ \ \ \ \ \ \theta,\phi             θ,ϕ ← \leftarrow Update parameters using gradients g
until convergence of parameters ( θ , ϕ \theta,\phi θ,ϕ)
return θ , ϕ \theta,\phi θ,ϕ

代码实现

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc2_mean = nn.Linear(400, 20)
        self.fc2_logvar = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc2_mean(h1), self.fc2_logvar(h1)

    def reparametrization(self, mu, logvar):
        # sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))
        std = 0.5 * torch.exp(logvar)
        # N(mu, std^2) = N(0, 1) * std + mu
        z = torch.randn(std.size()) * std + mu
        return z

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrization(mu, logvar)
        return self.decode(z), mu, logvar

5.变分自编码器模型分析

从上述的推导中,我们发现,变分自编码器的关键思想在于:它们可以通过最大化与数据点 x x x相关联的变分下界 L ( q ) L(q) L(q)来训练:

L ( q ) = E z ~ q ϕ ( z ∣ x ) l o g   p θ ( z , x ) + H ( ( q ϕ ( z ∣ x ) )    \mathcal L(q)=E_{z~q_{\phi}(z|x)}log\ p_{\theta}(z,x)+\mathcal H((q_{\phi}(z|x))\ \ L(q)=Ezqϕ(zx)log pθ(z,x)+H((qϕ(zx))  
                     = E z ~ q ϕ ( z ∣ x ) l o g   p θ ( x ∣ z ) − D K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ) ) \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =E_{z~q_{\phi}(z|x)}log\ p_{\theta}(x|z)-D_{KL}(q_{\phi}(z|x)||p_{\theta}(z))                     =Ezqϕ(zx)log pθ(xz)DKL(qϕ(zx)pθ(z))
≤ l o g p θ ( x )                                         \leq log p_{\theta}(x)\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ logpθ(x)                                       

在上式的第一行 中,我们将第一项视为潜变量的近似后验下可见和隐藏变量的联合对数似然性(正如 EM 一样,不同的是我们使用近似而不是精确后验)。第二项则可视为近似后验的熵。当 q ( z ) q(z) q(z) 被选择为高斯分布,其中噪声被添加到预测平均值时,最大化该熵项促使该噪声标准偏差的增加。更一般地,这个熵项鼓励变分后验将高概率质量置于可能已经产生 x x x 的许多 z z z 值上,而不是坍缩到单个估计最可能值的点。在上式第二行中,我们将第一项视为在其他自编码器中出现的重构对数似然。第二项试图使近似后验分布 q ( z ∣ x ) q(z | x) q(zx) 和模型先验 p θ ( z ) p_{\theta}(z) pθ(z)彼此接近。
变分推断和学习的传统方法是通过优化算法推断 q q q,通常是迭代不动点方程 。这些方法是缓慢的,并且通常需要以闭解形式计算 E z ∼ q   l o g   p θ ( z , x ) E_{z∼q} \ log \ p_{\theta}(z, x) Ezq log pθ(z,x)
变分自编码器背后的主要思想是训练产生 q q q参数的参数编码器(有时也称为推断网络或识别模型)。只要隐变量 z z z是连续变量,我们就可以通过从 q ( z ∣ x ) = q ( z ; f ( x ; θ ) ) q(z | x) = q(z; f (x; θ)) q(zx)=q(z;f(x;θ))中采样 z 的样本并进行反向传播,以获得相对于 θ \theta θ的梯度。学习则仅包括相对于编码器和解码器的参数最大化 L L L, L L L中的所有期望都可以通过蒙特卡罗采样来近似。

6.基于MNIST数据集的实验及结果分析

将模型在MNIST数据集上训练20轮,得到结果如图所示。
在这里插入图片描述

我们可以看到,虽然输出的数字结构正确,但是图片比较模糊,边缘线条有残缺。于是,我们可以得出,变分自编码成功地近似了输入数据的分布并且能够输出新的新样本数据。然而重构图像有些模糊,边缘不够平滑,噪声干扰大,说明映射函数表达能力可能有些欠缺,推断所得分布与原始数据分布的近似仍有一定偏差。
一种对模糊性可能的解释是,模糊性是最大似然的固有效应,因为我们需要最小化 D K L ( p d a t a ∣ ∣ p m o d e l ) D_{KL}(p_{data}||p_{model}) DKL(pdatapmodel)。这就意味着模型将为训练集中出现的点分配高的概率,但也可能为其他点分配高的概率。并且,最大化这种分布似然性以为这倾向于忽略由少量像素表示的特征或其中亮度变化微小的像素。
在这里插入图片描述

7.生成模型的拓展-GAN

7.1 GAN原理

从上述分析我们知道,变分自编码器虽然能近似推断输入数据的分布,但是由于KL散度的度量,最大似然估计可能存在固有的模糊性,所以重构图像不够平滑,噪声较多,而且映射函数表达能力欠缺,一旦输入数据分布非常复杂,就难以较好的拟合。而GAN(Generative Adversarial Networks, GAN)的改进之处就在于度量的选择和模型复杂性的增加。我们知道VAE使用变分贝叶斯方法,具体来讲是KL散度来度量两个分布之间的差距,那么GAN的主要想法在于,既然没有合适的度量,那么就通过训练一个神经网络把合适的度量给训练出来。
在这里插入图片描
生成对抗网络(GAN)也由两部分组成,一个是生成器,输入一个随机噪声,生成一张图片。另一个是判别器,用来判断输入的图片是真图片还是假图片。生成器的目标是尽可能生成能以假乱真的图片,让判别器以为这是真的图片,训练生成器时只用噪声生成假图片。而判别器的目标是尽可能将生成器生成的假图片于真实图片区分开,训练时需要利用生成器生成的假图片和真实的图片。判别器用来评估生成的假图片的质量,促使生成器相应地调整参数。

7.2 GAN代码实现
from torch import nn


class NetG(nn.Module):
    """
    生成器定义
    """

    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # 生成器feature map数

        self.main = nn.Sequential(
            # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
            nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf*8) x 4 x 4

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf) x 32 x 32

            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # 输出范围 -1~1 故而采用Tanh
            # 输出形状:3 x 96 x 96
        )

    def forward(self, input):
        return self.main(input)


class NetD(nn.Module):
    """
    判别器定义
    """

    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 输入 3 x 96 x 96
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf) x 32 x 32

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*2) x 16 x 16

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*4) x 8 x 8

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*8) x 4 x 4

            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  # 输出一个数(概率)
        )

    def forward(self, input):
        return self.main(input).view(-1)

7.3 使用GAN进行实验
7.3.1 MNIST数据生成

经过20轮训练所得到的数据
在这里插入图片描述

7.3.2 基于GAN的动漫图片生成

输入数据图片(RGB三通道)
在这里插入图片描述

经过20轮训练所得到的输出数据图片(RGB三通道)
在这里插入图片描述
经过200轮训练生成的图片(灰度值,减少计算量)
在这里插入图片描述

7.4 GAN实验结果分析

在MNIST数据集上实验的结果,对比VAE我们可以看到GAN生成的新样本数据更加清晰,图像边缘更加平滑。说明对输入分布拟合地更好,神经网络的表示能力更强。
而从动漫图像生成实验我们可以看出,GAN对更复杂的输入分布也有较好的推断。而且当训练轮数增加到40甚至200之后,生成的图片细节已经非常完善,线条更流畅,轮廓更清晰,不少生产出来的新数据已经能够以假乱真了。

8 总结

如何用有向概率模型从有限的输入数据推断出数据的分布并生成新的样本数据?变分自编码器基于贝叶斯变分法,通过对连续的隐变量采样,训练模型从而得到输入数据后验分布的近似推断。通过对MNIST数据集上的实验表明,VAE确实能近似地学习出输入分布并生成新的样本数据。然而,生成的图片数据虽然结构较输入数据差不多,但是图像有些模糊,边缘不平滑,线条不流畅。于是,相较于VAE,GAN又对合适的度量进行学习,能对更复杂的输入分布进行推断。通过MNIST数据集上的实验得出的结果明显优于VAE。而对与更复杂的分布推断任务,像本文中的动漫图像生成实验,我们所得到的生成图片细节已经非常完善,线条流畅,轮廓清晰,并且一些图片已经达到了一定以假乱真的地步。当我们使用更深层的GAN模型时,还会有更强的推断能力和生成能力,能学习并生成更复杂的图片。

9 代码

9.1 变分自编码器实现及在MNIST数据集上的实验(基于PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image


def loss_function(recon_x, x, mu, logvar):
   
    BCE_loss = nn.BCELoss(reduction='sum')
    reconstruction_loss = BCE_loss(recon_x, x)
    KL_divergence = -0.5 * torch.sum(1+logvar-torch.exp(logvar)-mu**2)
    #KLD_ele = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    #KLD = torch.sum(KLD_ele).mul_(-0.5)
    print(reconstruction_loss, KL_divergence)

    return reconstruction_loss + KL_divergence


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc2_mean = nn.Linear(400, 20)
        self.fc2_logvar = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc2_mean(h1), self.fc2_logvar(h1)

    def reparametrization(self, mu, logvar):
        # sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))
        std = 0.5 * torch.exp(logvar)
        # N(mu, std^2) = N(0, 1) * std + mu
        z = torch.randn(std.size()) * std + mu
        return z

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrization(mu, logvar)
        return self.decode(z), mu, logvar


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

vae = VAE()
optimizer = torch.optim.Adam(vae.parameters(), lr=0.0003)

# Training
def train(epoch):
    vae.train()
    all_loss = 0.
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to('cpu'), targets.to('cpu')
        real_imgs = torch.flatten(inputs, start_dim=1)

        # Train Discriminator
        gen_imgs, mu, logvar = vae(real_imgs)
        loss = loss_function(gen_imgs, real_imgs, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        all_loss += loss.item()
        print('Epoch {}, loss: {:.6f}'.format(epoch, all_loss/(batch_idx+1)))
        # Save generated images for every epoch
    fake_images = gen_imgs.view(-1, 1, 28, 28)
    save_image(fake_images, 'MNIST_FAKE/fake_images-{}.png'.format(epoch + 1))



for epoch in range(20):
    train(epoch)

torch.save(vae.state_dict(), './vae.pth')



9.2 GAN生成漫画图像(基于PyTorch搭建)
'''model.py'''
from torch import nn


class NetG(nn.Module):
    """
    生成器定义
    """

    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # 生成器feature map数

        self.main = nn.Sequential(
            # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
            nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf*8) x 4 x 4

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf) x 32 x 32

            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # 输出范围 -1~1 故而采用Tanh
            # 输出形状:3 x 96 x 96
        )

    def forward(self, input):
        return self.main(input)


class NetD(nn.Module):
    """
    判别器定义
    """

    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 输入 3 x 96 x 96
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf) x 32 x 32

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*2) x 16 x 16

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*4) x 8 x 8

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*8) x 4 x 4

            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  # 输出一个数(概率)
        )

    def forward(self, input):
        return self.main(input).view(-1)

''' visualize.py'''
from itertools import chain
import visdom
import torch
import time
import torchvision as tv
import numpy as np


class Visualizer():
    """
    封装了visdom的基本操作,但是你仍然可以通过`self.vis.function`
    调用原生的visdom接口
    """

    def __init__(self, env='default', **kwargs):
        import visdom
        self.vis = visdom.Visdom(env=env, use_incoming_socket=False,**kwargs)

        # 画的第几个数,相当于横座标
        # 保存(’loss',23) 即loss的第23个点
        self.index = {}
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        """
        修改visdom的配置
        """
        self.vis = visdom.Visdom(env=env,use_incoming_socket=False, **kwargs)
        return self

    def plot_many(self, d):
        """
        一次plot多个
        @params d: dict (name,value) i.e. ('loss',0.11)
        """
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y):
        """
        self.plot('loss',1.00)
        """
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=(name),
                      opts=dict(title=name),
                      update=None if x == 0 else 'append'
                      )
        self.index[name] = x + 1

    def img(self, name, img_):
        """
        self.img('input_img',t.Tensor(64,64))
        """

        if len(img_.size()) < 3:
            img_ = img_.cpu().unsqueeze(0)
        self.vis.image(img_.cpu(),
                       win=(name),
                       opts=dict(title=name)
                       )

    def img_grid_many(self, d):
        for k, v in d.items():
            self.img_grid(k, v)

    def img_grid(self, name, input_3d):
        """
        一个batch的图片转成一个网格图,i.e. input(36,64,64)
        会变成 6*6 的网格图,每个格子大小64*64
        """
        self.img(name, tv.utils.make_grid(
            input_3d.cpu()[0].unsqueeze(1).clamp(max=1, min=0)))

    def log(self, info, win='log_text'):
        """
        self.log({'loss':1,'lr':0.0001})
        """

        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info))
        self.vis.text(self.log_text, win=win)

    def __getattr__(self, name):
        return getattr(self.vis, name)

'''main.py'''
import os
import torch as t
import torchvision as tv
import tqdm
from model import NetG, NetD
from torchnet.meter import AverageValueMeter


class Config(object):
    data_path = 'data/'  # 数据集存放路径
    num_workers = 4  # 多进程加载数据所用的进程数
    image_size = 96  # 图片尺寸
    batch_size = 256
    max_epoch = 200
    lr1 = 2e-4  # 生成器的学习率
    lr2 = 2e-4  # 判别器的学习率
    beta1 = 0.5  # Adam优化器的beta1参数
    gpu = False  # 是否使用GPU
    nz = 100  # 噪声维度
    ngf = 64  # 生成器feature map数
    ndf = 64  # 判别器feature map数

    save_path = 'imgs/'  # 生成图片保存路径

    vis = True  # 是否使用visdom可视化
    env = 'GAN'  # visdom的env
    plot_every = 20  # 每间隔20 batch,visdom画图一次

    debug_file = '/tmp/debuggan'  # 存在该文件则进入debug模式
    d_every = 1  # 每1个batch训练一次判别器
    g_every = 5  # 每5个batch训练一次生成器
    save_every = 10  # 没10个epoch保存一次模型
    netd_path = None  # 'checkpoints/netd_.pth' #预训练模型
    netg_path = None  # 'checkpoints/netg_211.pth'

    # 只测试不训练
    gen_img = 'result.png'
    # 从512张生成的图片中保存最好的64张
    gen_num = 64
    gen_search_num = 512
    gen_mean = 0  # 噪声的均值
    gen_std = 1  # 噪声的方差


opt = Config()


def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device=t.device('cuda') if opt.gpu else t.device('cpu')
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # 数据
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

    # 网络
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # 定义优化器和损失
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # 真图片label为1,假图片label为0
    # noises为生成网络的输入
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()


    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # 训练判别器
                optimizer_d.zero_grad()
                ## 尽可能的把真图片判别为正确
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## 尽可能把假图片判别为错误
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # 根据噪声生成假图
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # 训练生成器
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## 可视化
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch+1) % opt.save_every == 0:
            # 保存模型、图片
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()


@t.no_grad()
def generate(**kwargs):
    """
    随机生成动漫头像,并根据netd的分数选择较好的
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    
    device=t.device('cuda') if opt.gpu else t.device('cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # 生成图片,并计算图片在判别器的分数
    fake_img = netg(noises)
    scores = netd(fake_img).detach()

    # 挑选最好的某几张
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # 保存图片
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))


if __name__ == '__main__':
    import fire
    fire.Fire()

10.参考文献

Abadi, M. and Andersen, D. G. (2016). Learning to protect communications with adversarial neural cryptography. arXiv preprint arXiv:1610.06918.
Bengio, Y., Thibodeau-Laufer, E., Alain, G., and Yosinski, J. (2014). Deep generative stochastic networks trainable by backprop. In ICML’2014.
Brock, A., Lim, T., Ritchie, J. M., and Weston, N. (2016). Neural photo editing with introspective adversarial networks. CoRR, abs/1609.07093.
Chen, X., Duan, Y., Houthooft, R., Schulman, J., Sutskever, I., and Abbeel, P. (2016a). Infogan: Interpretable representation learning by information maximizing generative adversarial nets. In Advances in Neural Information Processing Systems, pages 2172–2180.
Chen, X., Kingma, D. P., Salimans, T., Duan, Y., Dhariwal, P., Schulman, J., Sutskever, I., and Abbeel, P. (2016b). Variational lossy autoencoder. arXiv preprint arXiv:1611.02731.
Deco, G. and Brauer, W. (1995). Higher order statistical decorrelation without information loss. NIPS.
Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. (2009). Ima- geNet: A Large-Scale Hierarchical Image Database. In CVPR09.
Deng, J., Berg, A. C., Li, K., and Fei-Fei, L. (2010). What does classifying more than 10,000 image categories tell us? In Proceedings of the 11th European Conference on Computer Vision: Part V, ECCV’10, pages 71–84, Berlin, Heidelberg. Springer-Verlag.
Denton, E., Chintala, S., Szlam, A., and Fergus, R. (2015). Deep generative image models using a Laplacian pyramid of adversarial networks. NIPS.
Dumoulin, V., Belghazi, I., Poole, B., Lamb, A., Arjovsky, M., Mastropietro, O., and Courville, A. (2016). Adversarially learned inference. arXiv preprint arXiv:1606.00704 .
Dziugaite, G. K., Roy, D. M., and Ghahramani, Z. (2015). Training generative neural networks via maximum mean discrepancy optimization. arXiv preprint arXiv:1505.03906 .
Finn, C. and Levine, S. (2016). Deep visual foresight for planning robot motion. arXiv preprint arXiv:1610.00696.
Finn, C., Christiano, P., Abbeel, P., and Levine, S. (2016a). A connection between generative adversarial networks, inverse reinforcement learning, and energy-based models. arXiv preprint arXiv:1611.03852.
Finn, C., Goodfellow, I., and Levine, S. (2016b). Unsupervised learning for physical interaction through video prediction. NIPS.
Frey, B. J. (1998). Graphical models for machine learning and digital commu- nication. MIT Press.
Frey, B. J., Hinton, G. E., and Dayan, P. (1996). Does the wake-sleep algorithm learn good density estimators? In D. Touretzky, M. Mozer, and M. Hasselmo, editors, Advances in Neural Information Processing Systems 8 (NIPS’95), pages 661–670. MIT Press, Cambridge, MA.
Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., and Lempitsky, V. (2015). Domain-adversarial training of neural networks. arXiv preprint arXiv:1505.07818.
Goodfellow, I., Bengio, Y., and Courville, A. (2016). Deep Learning. MIT Press. http://www.deeplearningbook.org.
Goodfellow, I. J. (2014). On distinguishability criteria for estimating generative models. In International Conference on Learning Representations, Workshops Track .
Goodfellow, I. J., Shlens, J., and Szegedy, C. (2014a). Explaining and harnessing adversarial examples. CoRR, abs/1412.6572.
Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. (2014b). Generative adversarial net- works. In NIPS’2014.
Ho, J. and Ermon, S. (2016). Generative adversarial imitation learning. In Advances in Neural Information Processing Systems, pages 4565–4573.
Ioffe, S. and Szegedy, C. (2015). Batch normalization: Accelerating deep net- work training by reducing internal covariate shift.
Isola, P., Zhu, J.-Y., Zhou, T., and Efros, A. A. (2016). Image-to-image transla- tion with conditional adversarial networks. arXiv preprint arXiv:1611.07004 .
Jang, E., Gu, S., and Poole, B. (2016). Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144
Kingma, D. and Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
Kingma, D. P. (2013). Fast gradient-based inference with continuous latent variable models in auxiliary form. Technical report, arxiv:1306.0733.
Kingma, D. P., Salimans, T., and Welling, M. (2016). Improving variational inference with inverse autoregressive flow. NIPS.
Ledig, C., Theis, L., Huszar, F., Caballero, J., Aitken, A. P., Tejani, A., Totz, J., Wang, Z., and Shi, W. (2016). Photo-realistic single image super-resolution using a generative adversarial network. CoRR, abs/1609.04802.
Li, Y., Swersky, K., and Zemel, R. S. (2015). Generative moment matching networks. CoRR, abs/1502.02761.
Lotter, W., Kreiman, G., and Cox, D. (2015). Unsupervised learning of visual structure using predictive generative networks. arXiv preprint arXiv:1511.06380 .
Maddison, C. J., Mnih, A., and Teh, Y. W. (2016). The concrete distribu- tion: A continuous relaxation of discrete random variables. arXiv preprint arXiv:1611.00712 .
Metz, L., Poole, B., Pfau, D., and Sohl-Dickstein, J. (2016). Unrolled generative adversarial networks. arXiv preprint arXiv:1611.02163.
Nguyen, A., Yosinski, J., Bengio, Y., Dosovitskiy, A., and Clune, J. (2016). Plug & play generative networks: Conditional iterative generation of images in latent space. arXiv preprint arXiv:1612.00005.
Nowozin, S., Cseke, B., and Tomioka, R. (2016). f-gan: Training generative neural samplers using variational divergence minimization. arXiv preprint arXiv:1606.00709 .
Odena, A. (2016). Semi-supervised learning with generative adversarial net- works. arXiv preprint arXiv:1606.01583.
Oord, A. v. d., Dieleman, S., Zen, H., Simonyan, K., Vinyals, O., Graves, A., Kalchbrenner, N., Senior, A., and Kavukcuoglu, K. (2016). Wavenet: A generative model for raw audio. arXiv preprint arXiv:1609.03499.
Pfau, D. and Vinyals, O. (2016). Connecting generative adversarial networks and actor-critic methods. arXiv preprint arXiv:1610.01945.
Radford, A., Metz, L., and Chintala, S. (2015). Unsupervised representa- tion learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.
Ratliff, L. J., Burden, S. A., and Sastry, S. S. (2013). Characterization and computation of local nash equilibria in continuous games. In Communication, Control, and Computing (Allerton), 2013 51st Annual Allerton Conference on, pages 917–924. IEEE.
Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. (2016). Improved techniques for training gans. In Advances in Neural Information Processing Systems, pages 2226–2234.
Silver, D., Huang, A., Maddison, C. J., Guez, A., Sifre, L., Van Den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., et al. (2016). Mastering the game of go with deep neural networks and tree search. Nature, 529(7587), 484–489.
Springenberg, J. T. (2015). Unsupervised and semi-supervised learning with categorical generative adversarial networks. arXiv preprint arXiv:1511.06390 .
Springenberg, J. T., Dosovitskiy, A., Brox, T., and Riedmiller, M. (2015). Striv- ing for simplicity: The all convolutional net. In ICLR.
Szegedy, C., Zaremba, W., Sutskever, I., Bruna, J., Erhan, D., Goodfellow, I. J., and Fergus, R. (2014). Intriguing properties of neural networks. ICLR, abs/1312.6199.
Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., and Wojna, Z. (2015). Re- thinking the Inception Architecture for Computer Vision. ArXiv e-prints.
Zhang, H., Xu, T., Li, H., Zhang, S., Huang, X., Wang, X., and Metaxas, D. (2016). Stackgan: Text to photo-realistic image synthesis with stacked generative adversarial networks. arXiv preprint arXiv:1612.03242.
Zhu, J.-Y., Kr ̈ahenbu ̈hl, P., Shechtman, E., and Efros, A. A. (2016). Generative visual manipulation on the natural image manifold. In European Conference on Computer Vision, pages 597–613. Springer.

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值