vae公式推导+源码
文章目录
一、AE 自编码器
编码器网络可以将原始高维网络转换为潜在的低维代码
解码器网络可以从低维代码中恢复原始数据,并且可能具有越来越大的输出层
自编码器针对从代码重建数据进行了显式优化。一个好的中间表示不仅可以捕获潜在变量,而且有利于完整的解压缩过程。
变分编码器和自动编码器的区别就在于:传统自动编码器的隐变量 z z z的分布是不知道的,因此我们无法采样得到新的 z z z,也就无法通过解码器得到新的 x x x。
AE的缺点:映射空间不连续,无规则,无界
VAE将每组数据编码为一个分布
二、AEs 正则自编码器
1. 去噪自编码器
在输入数据时加入噪声,强化特征提取能力
2. 稀疏自编码器
dropout
3. 对抗式自编码器
与GAN网络结合
三、VAE 变分自编码器
1. 数学知识
1.1 贝叶斯公式(Bayes Rule)
公式表述为:
p
(
z
∣
x
)
=
p
(
x
,
z
)
p
(
x
)
=
p
(
x
∣
z
)
p
(
z
)
p
(
x
)
(1)
p(z|x) = \frac{p(x,z)}{p(x)} = \frac{p(x|z)p(z)}{p(x)} \tag{1}
p(z∣x)=p(x)p(x,z)=p(x)p(x∣z)p(z)(1)
1.2 KL散度
K − L K-L K−L 散度又被称为相对熵(relative entropy),是对两个概率分布间差异的非对称性度量。
假设 p ( x ) p(x) p(x) , q ( x ) q(x) q(x)是随机变量上的两个概率分布,
在离散随机变量的情况下,相对熵的定义为:
K
L
(
(
p
(
x
)
∣
∣
q
(
x
)
)
)
=
∑
p
(
x
)
log
p
(
x
)
q
(
x
)
(2)
KL((p(x)||q(x))) = \sum p(x) \log{\frac{p(x)}{q(x)}} \tag{2}
KL((p(x)∣∣q(x)))=∑p(x)logq(x)p(x)(2)
在连续随机变量的情况下,相对熵的定义为:
K
L
(
(
p
(
x
)
∣
∣
q
(
x
)
)
)
=
∫
p
(
x
)
log
p
(
x
)
q
(
x
)
d
x
(3)
KL((p(x)||q(x))) = \int p(x) \log{\frac{p(x)}{q(x)}}dx \tag{3}
KL((p(x)∣∣q(x)))=∫p(x)logq(x)p(x)dx(3)
1.3 EM算法
EM算法(期望最大算法)是一种迭代算法,用于含有隐变量的概率参数模型的最大似然估计或极大后验概率估计。具体思想如下:
EM算法的核心思想非常简单,分为两步:Expection-Step和Maximization-Step
E-Step主要通过观察数据和现有模型来估计参数,然后用这个估计的参数值来计算似然函数的期望值
M-Step是寻找似然函数最大化时对应的参数。由于算法会保证在每次迭代之后似然函数都会增加,所以函数最终会收敛。
公式表述为:
ln
p
θ
(
x
)
=
K
L
(
q
(
z
)
∣
∣
p
θ
(
z
∣
x
)
)
+
∫
q
(
z
)
ln
p
θ
(
x
,
z
)
q
(
z
)
d
z
(4)
\ln{p_{\theta}(x)} = KL(q(z)||p_{\theta}(z|x)) + \int q(z) \ln{\frac{p_{\theta}(x,z)}{q(z)}}dz \tag{4}
lnpθ(x)=KL(q(z)∣∣pθ(z∣x))+∫q(z)lnq(z)pθ(x,z)dz(4)
= K L ( q ( z ) ∣ ∣ p θ ( z ∣ x ) ) + L θ ( q , x ) = KL(q(z)||p_{\theta}(z|x)) + L_{\theta}(q,x) =KL(q(z)∣∣pθ(z∣x))+Lθ(q,x)
E-step:用来固定
θ
\theta
θ,求
q
(
z
)
q(z)
q(z):
q
t
+
1
(
z
)
=
a
r
g
m
a
x
q
L
θ
t
(
q
,
x
)
(5)
q_{t+1}(z) = \mathop {argmax}_{q}L_{\theta_{t}}(q,x) \tag{5}
qt+1(z)=argmaxqLθt(q,x)(5)
如果 p θ ( z ∣ x ) ) p_{\theta}(z|x)) pθ(z∣x)) 可以得出,则 q ( z ) q(z) q(z) 就等于 p θ ( z ∣ x ) ) p_{\theta}(z|x)) pθ(z∣x)) ,如果不能得出,就利用变分推断近似估计 q ( z ) q(z) q(z)
M-step:用来固定 q ( z ) q(z) q(z) ,求 θ \theta θ:
θ t + 1 = a r g m a x θ L θ ( q t + 1 , x ) (6) \theta_{t+1} = \mathop {argmax}_{\theta}L_{\theta}(q_{t+1},x) \tag{6} θt+1=argmaxθLθ(qt+1,x)(6)
1.4 变分推断
参数估计:根据样本中提供的相关信息,对总体分布中的未知参数 z z z进行估值
使用贝叶斯估计对未知参数估值:
p
(
z
∣
x
)
=
p
(
x
,
z
)
p
(
x
)
=
p
(
x
∣
z
)
p
(
z
)
p
(
x
)
(上述公式1)
p(z|x) = \frac{p(x,z)}{p(x)} = \frac{p(x|z)p(z)}{p(x)} \tag{上述公式1}
p(z∣x)=p(x)p(x,z)=p(x)p(x∣z)p(z)(上述公式1)
其中,
p
(
x
∣
z
)
p(x|z)
p(x∣z) 为极大似然估计,
p
(
z
)
p(z)
p(z) 为最大后验估计。然而贝叶斯估计中的分母
p
(
x
)
p(x)
p(x) 一般得不出结果,只能找一个与之相似的函数 $q(z) \approx p(z|x) $ 代替,并通过
K
−
L
K-L
K−L散度求出函数间的近似程度。
目标函数:
min
K
L
(
(
q
(
z
)
∣
∣
p
(
z
∣
x
)
)
)
(2)
\min KL((q(z)||p(z|x))) \tag{2}
minKL((q(z)∣∣p(z∣x)))(2)
则:
K
L
(
(
q
(
z
)
∣
∣
p
(
z
∣
x
)
)
)
=
∫
q
(
z
)
log
q
(
z
)
p
(
z
∣
x
)
d
z
=
∫
q
(
z
)
log
q
(
z
)
p
(
x
,
z
)
d
z
+
ln
p
(
x
)
(3)
KL((q(z)||p(z|x))) = \int q(z) \log{\frac{q(z)}{p(z|x)}}dz \tag{3} =\int q(z) \log{\frac{q(z)}{p(x,z)}}dz + \ln{p(x)}
KL((q(z)∣∣p(z∣x)))=∫q(z)logp(z∣x)q(z)dz=∫q(z)logp(x,z)q(z)dz+lnp(x)(3)
ln p ( x ) = K L ( ( q ( z ) ∣ ∣ p ( z ∣ x ) ) ) + ∫ q ( z ) log p ( x , z ) q ( z ) d z = K L ( ( q ( z ) ∣ ∣ p ( z ∣ x ) ) ) + L ( q ) (4) \ln{p(x)} = KL((q(z)||p(z|x))) + \int q(z) \log{\frac{p(x,z)}{q(z)}}dz \tag{4} = KL((q(z)||p(z|x))) + L(q) lnp(x)=KL((q(z)∣∣p(z∣x)))+∫q(z)logq(z)p(x,z)dz=KL((q(z)∣∣p(z∣x)))+L(q)(4)
因为
ln
p
(
x
)
\ln{p(x)}
lnp(x)是与
z
z
z无关的常量,并不改变,所以要求的目标函数
min
K
L
(
(
q
(
z
)
∣
∣
p
(
z
∣
x
)
)
)
\min KL((q(z)||p(z|x)))
minKL((q(z)∣∣p(z∣x))) 相当于
max
L
(
q
)
\max L(q)
maxL(q) ,而变分贝叶斯学习可以通过
q
(
z
)
q(z)
q(z) 的迭代实现
L
(
q
)
L(q)
L(q) 的最大化,即
E
L
B
O
ELBO
ELBO变分下界,则:
max
L
(
q
)
=
∫
q
(
z
)
log
p
(
x
,
z
)
d
z
−
∫
q
(
z
)
log
q
(
z
)
d
z
(5)
\max L(q) = \int q(z) \log{p(x,z)}dz - \int q(z) \log{q(z)}dz \tag{5}
maxL(q)=∫q(z)logp(x,z)dz−∫q(z)logq(z)dz(5)
平均场理论:变分分布
q
(
z
)
q(z)
q(z)可以因式分解为
M
M
M个互不相交的部分
q
(
z
)
=
∏
i
=
I
M
q
i
(6)
q(z) = \prod_{i=I}^{M}{q_i} \tag{6}
q(z)=i=I∏Mqi(6)
则:
E L B O = max L ( q ) = ∫ q ( z ) log p ( x , z ) q ( z ) d z (7) ELBO = \max L(q) = \int q(z) \log{\frac{p(x,z)}{q(z)}}dz \tag{7} ELBO=maxL(q)=∫q(z)logq(z)p(x,z)dz(7)
= ∫ z ∏ i = 1 M q i [ log p ( x , z ) − log ∏ k = 1 M q k ] d z = \int_z \prod_{i=1}^{M}{q_i}[ \log{p(x,z)} - \log{\prod_{k=1}^{M}q_{k}}]dz =∫zi=1∏Mqi[logp(x,z)−logk=1∏Mqk]dz
= ∫ z q j [ ∏ i ≠ j M q i log p ( x , z ) ] d z − ∫ z ∏ i = 0 M q i ∑ k = 1 M log q k d z = \int_z q_j [\prod_{i \neq j}^{M}{q_i} \log{p(x,z)}]dz - \int_{z}\prod_{i=0}^{M}q_{i}\sum_{k=1}^{M}\log{q_k}dz =∫zqj[i=j∏Mqilogp(x,z)]dz−∫zi=0∏Mqik=1∑Mlogqkdz
分别对➖左右进行计算
L
e
f
t
=
∫
z
q
j
[
∏
i
≠
j
M
q
i
log
p
(
x
,
z
)
]
d
z
(8)
Left = \int_z q_j [\prod_{i \neq j}^{M}{q_i} \log{p(x,z)}]dz \tag{8}
Left=∫zqj[i=j∏Mqilogp(x,z)]dz(8)
= ∫ z j q j [ ∫ z i ≠ j ∏ i ≠ j q i log p ( x , z ) d z i ≠ j ] d z j = \int_{z_j}q_j[\int_{z_{i \neq j}} \prod_{i \neq j}q_i\log{p(x,z)}dz_{i \neq j}]dz_j =∫zjqj[∫zi=ji=j∏qilogp(x,z)dzi=j]dzj
= ∫ z j q j E i ≠ j [ log p ( x , z ) ] d z j = \int_{z_j} q_j \mathbb{E}_{i \neq j}[\log{p(x,z)}]dz_j =∫zjqjEi=j[logp(x,z)]dzj
R i g h t = ∫ z ∏ i = 0 M q i ∑ k = 1 M log q k d z + C 1 (9) Right = \int_{z}\prod_{i=0}^{M}q_{i}\sum_{k=1}^{M}\log{q_k}dz + C1 \tag{9} Right=∫zi=0∏Mqik=1∑Mlogqkdz+C1(9)
= ∫ z j q j log q j d z j ∏ i ≠ j ∫ z i q i d z i + C 1 = \int_{z_j} q_j\log{q_j}dz_j \prod_{i \neq j}\int_{z_i}q_i dz_i + C1 =∫zjqjlogqjdzji=j∏∫ziqidzi+C1
= ∫ z j q j log q j d z j + C 1 = \int_{z_j} q_j\log{q_j}dz_j + C1 =∫zjqjlogqjdzj+C1
其中: C 1 = − ∫ z ∏ i = 0 M q i ∑ i ≠ j log q j d z C1 = - \int_z \prod_{i=0}^{M}q_i\sum{i \neq j}\log{q_j}dz C1=−∫z∏i=0Mqi∑i=jlogqjdz
同时,令 log p ~ ( x , z j ) = E i ≠ j [ log p ( x , z ) ] + C 2 \log{\tilde{p}(x,z_j)} = \mathbb{E}_{i \neq j}[\log{p(x,z)}]+C2 logp~(x,zj)=Ei=j[logp(x,z)]+C2 ( C 2 C2 C2用于归一化,等式左边对一个概率取对数,因此需要保证概率的性质)
则
E
L
B
O
ELBO
ELBO最终化为:
E
L
B
O
=
∫
z
j
q
j
E
i
≠
j
[
log
p
(
x
,
z
)
]
d
z
j
−
∫
z
j
q
j
log
q
j
d
z
j
−
C
1
(10)
ELBO = \int_{z_j} q_j \mathbb{E}_{i \neq j}[\log{p(x,z)}]dz_j - \int_{z_j} q_j\log{q_j}dz_j - C1 \tag{10}
ELBO=∫zjqjEi=j[logp(x,z)]dzj−∫zjqjlogqjdzj−C1(10)
= ∫ z j q j log p ~ ( x , z j ) q j d z j + C 3 = \int_{z_j} q_j\log{\frac{\tilde{p}(x,z_j)}{q_j}}dz_j + C3 =∫zjqjlogqjp~(x,zj)dzj+C3
= − K L ( q j ∣ ∣ p ~ ( x , z j ) ) + C 3 = - KL(q_j||\tilde{p}(x,z_j)) + C3 =−KL(qj∣∣p~(x,zj))+C3
则当
q
j
→
p
~
(
x
,
z
j
)
q_j \rightarrow \tilde{p}(x,z_j)
qj→p~(x,zj)时,
E
L
B
O
ELBO
ELBO最大,即
L
(
q
)
L(q)
L(q)取最大值
log
q
j
(
z
j
)
=
log
p
~
(
x
,
z
j
)
(11)
\log{q_j(z_j)} = \log{\tilde{p}(x,z_j)} \tag{11}
logqj(zj)=logp~(x,zj)(11)
= E i ≠ j [ log p ( x , z ) ] + C 2 = \mathbb{E}_{i \neq j}[\log{p(x,z)}]+C2 =Ei=j[logp(x,z)]+C2
则:
q
j
(
z
j
)
=
exp
(
E
i
≠
j
[
log
p
(
x
,
z
)
]
+
C
2
)
(12)
q_j(z_j) = \exp{( \mathbb{E}_{i \neq j}[\log{p(x,z)}]+C2)} \tag{12}
qj(zj)=exp(Ei=j[logp(x,z)]+C2)(12)
∫ z j q j ( z j ) d z j = exp ( C 2 ) ∫ z j exp ( E i ≠ j [ log p ( x , z ) ] ) d z j \int_{z_j}q_j(z_j)dz_j = \exp{(C2)}\int_{z_j}\exp{(\mathbb{E}_{i \neq j}[\log{p(x,z)}])}dz_j ∫zjqj(zj)dzj=exp(C2)∫zjexp(Ei=j[logp(x,z)])dzj
可以求出:
C
2
=
log
1
∫
z
j
exp
(
E
i
≠
j
[
log
p
(
x
,
z
)
]
)
d
z
j
(13)
C2 = \log{\frac{1}{\int_{z_j}\exp{(\mathbb{E}_{i \neq j}[\log{p(x,z)}])}dz_j}} \tag{13}
C2=log∫zjexp(Ei=j[logp(x,z)])dzj1(13)
返回带入公式12,得出:
q
j
(
z
j
)
=
exp
(
E
i
≠
j
[
log
p
(
x
,
z
)
]
)
∫
z
j
exp
(
E
i
≠
j
[
log
p
(
x
,
z
)
]
)
d
z
j
(14)
q_j(z_j) = \frac{\exp{( \mathbb{E}_{i \neq j}[\log{p(x,z)}])}}{\int_{z_j}\exp{(\mathbb{E}_{i \neq j}[\log{p(x,z)}])}dz_j} \tag{14}
qj(zj)=∫zjexp(Ei=j[logp(x,z)])dzjexp(Ei=j[logp(x,z)])(14)
2. 公式推导
参数 | 含义 | 数学表示 | 含义 |
---|---|---|---|
x x x | 观测数据 | q ϕ ( z ∣ x ) / q ϕ z q_{\phi}(z|x)/q_{\phi}z qϕ(z∣x)/qϕz | 使用变分参数 ϕ \phi ϕ 来近似真实后验 p θ ( z ∣ x ) p_{\theta}(z|x) pθ(z∣x) |
z z z | 隐变量 | p θ ( z ∣ x ) p_{\theta}(z|x) pθ(z∣x) | 给定数据 x x x 后隐变量 z z z 的真实后验分布 |
ϕ \phi ϕ | 变分参数(近似后验分布的参数) | p θ ( x , z ) p_{\theta}(x,z) pθ(x,z) | 生成数据 x x x 和隐变量 z z z 的联合概率分布 |
θ \theta θ | 生成模型的参数 | p θ ( x ∣ z ) p_{\theta}(x |z) pθ(x∣z) | 给定隐变量 z z z 后数据 x x x 的条件概率分布 |
在变分自编码器(VAE)的推导过程中,我们使用变分推断来近似真实后验分布,并且通过优化目标函数来学习生成模型的参数。
推导过程
-
目标:最大化数据的对数似然
变分自编码器的目标是最大化观测数据 x x x 的对数似然,即:
log p θ ( x ) (1) \log p_{\theta}(x) \tag{1} logpθ(x)(1)
由于计算 p θ ( x ) p_{\theta}(x) pθ(x) 直接计算非常困难,我们通过引入隐变量 z z z 和变分推断来简化。 -
引入隐变量
根据隐变量的定义,我们可以将对数似然分解为:
log p θ ( x ) = log ∫ z p θ ( x , z ) d z (2) \log p_{\theta}(x) = \log \int_{z} p_{\theta}(x, z) \, dz \tag{2} logpθ(x)=log∫zpθ(x,z)dz(2) -
应用变分下界(ELBO)
为了优化这个目标,我们引入变分分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x) 来近似真实后验分布 p θ ( z ∣ x ) p_{\theta}(z|x) pθ(z∣x)。通过引入变分分布,我们可以使用变分下界(Evidence Lower Bound, ELBO)来逼近对数似然。ELBO 的推导如下:
log p θ ( x ) = log ∫ z p θ ( x , z ) d z (3) \log p_{\theta}(x) = \log \int_{z} p_{\theta}(x, z)dz \tag{3} logpθ(x)=log∫zpθ(x,z)dz(3)
引入变分分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x):
log p θ ( x ) = log ∫ z q ϕ ( z ∣ x ) q ϕ ( z ∣ x ) p θ ( x , z ) d z (4) \log p_{\theta}(x) = \log \int_{z} \frac{q_{\phi}(z|x)}{q_{\phi}(z|x)} p_{\theta}(x, z) \, dz \tag{4} logpθ(x)=log∫zqϕ(z∣x)qϕ(z∣x)pθ(x,z)dz(4)使用对数的变换性质:
log p θ ( x ) = log [ ∫ z q ϕ ( z ∣ x ) p θ ( x , z ) q ϕ ( z ∣ x ) d z ] \log p_{\theta}(x) = \log \left[ \int_{z} q_{\phi}(z|x) \frac{p_{\theta}(x, z)}{q_{\phi}(z|x)} \, dz \right] logpθ(x)=log[∫zqϕ(z∣x)qϕ(z∣x)pθ(x,z)dz]应用对数-期望的变换:
log p θ ( x ) ≥ E q ϕ ( z ∣ x ) [ log p θ ( x , z ) − log q ϕ ( z ∣ x ) ] (5) \log p_{\theta}(x) \geq \mathbb{E}_{q_{\phi}(z|x)} \left[ \log p_{\theta}(x, z) - \log q_{\phi}(z|x) \right] \tag{5} logpθ(x)≥Eqϕ(z∣x)[logpθ(x,z)−logqϕ(z∣x)](5)这个期望值被称为变分下界(ELBO):
ELBO = E q ϕ ( z ∣ x ) [ log p θ ( x , z ) − log q ϕ ( z ∣ x ) ] (6) \text{ELBO} = \mathbb{E}_{q_{\phi}(z|x)} \left[ \log p_{\theta}(x, z) - \log q_{\phi}(z|x) \right] \tag{6} ELBO=Eqϕ(z∣x)[logpθ(x,z)−logqϕ(z∣x)](6)
其中:
ELBO = E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) ] − KL ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ) ) (7) \text{ELBO} = \mathbb{E}_{q_{\phi}(z|x)} \left[ \log p_{\theta}(x|z) \right] - \text{KL}(q_{\phi}(z|x) \, || \, p_{\theta}(z)) \tag{7} ELBO=Eqϕ(z∣x)[logpθ(x∣z)]−KL(qϕ(z∣x)∣∣pθ(z))(7)
这里,KL 散度(Kullback-Leibler divergence)是:
KL ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ) ) = ∫ z q ϕ ( z ∣ x ) log q ϕ ( z ∣ x ) p θ ( z ) d z (8) \text{KL}(q_{\phi}(z|x) \, || \, p_{\theta}(z)) = \int_{z} q_{\phi}(z|x) \log \frac{q_{\phi}(z|x)}{p_{\theta}(z)} \, dz \tag{8} KL(qϕ(z∣x)∣∣pθ(z))=∫zqϕ(z∣x)logpθ(z)qϕ(z∣x)dz(8) -
优化目标
最终的目标是最大化 ELBO,从而间接最大化对数似然。通过优化变分参数 ϕ \phi ϕ 和生成参数 θ \theta θ,使得 ELBO 最大化,即可获得最优的模型参数。
-
总结
- 近似后验 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x):通过变分推断来近似真实后验。
- 真实后验 p θ ( z ∣ x ) p_{\theta}(z|x) pθ(z∣x):目标是通过优化得到近似。
- 生成模型 p θ ( x , z ) p_{\theta}(x, z) pθ(x,z):模型定义生成数据的过程。
- 条件概率分布 p θ ( x ∣ z ) p_{\theta}(x|z) pθ(x∣z):表示给定隐变量 z z z 的情况下数据 x x x 的分布。
VAE 的关键在于利用 ELBO 来进行模型优化,通过最大化 ELBO 来逼近真实对数似然,从而获得高效的生成模型。
3. 代码实现
以下为基于pytorch的代码实现:
vae.py
import torch
from torch import nn
import torch.nn.functional as F
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
# 编码器所用的结构
self.fc1 = nn.Linear(784, 200)
self.fc2_mu = nn.Linear(200, 20) # 用于生成高斯分布的均值
self.fc2_log_std = nn.Linear(200, 20) # 用于生成高斯分布的方差,且为方便计算默认方差是经过log函数处理的。
# 解码器所用的结构
self.fc3 = nn.Linear(20, 200)
self.fc4 = nn.Linear(200, 784)
def encoder(self, x):
h1 = F.relu(self.fc1(x))
mu = self.fc2_mu(h1) # 生成均值
logvar = self.fc2_log_std(h1) # 生成经过log处理的方差
return mu, logvar
def decoder(self, z):
h3 = F.relu(self.fc3(z))
recon = torch.sigmoid(self.fc4(h3)) # 之所以用sigmoid是因为本例用到的图像默认像素值为0-1之间。
return recon
def reparametrize(self, mu, logvar):
var = torch.exp(logvar) # 因为生成的方差是经过log处理的,所以真正要用到方差的时候要再把它经过exp处理一下。
eps = torch.randn_like(var) # 在标准正态分布中采样
z = mu + eps * var # 获得抽取的z
return z
def forward(self, x):
mu, logvar = self.encoder(x)
z = self.reparametrize(mu, logvar)
recon = self.decoder(z)
return recon, mu, logvar # 返回重构的图,均值,log后的方差
def loss_function(self, recon, x, mu, logvar):
recon_loss = F.mse_loss(recon, x, reduction="sum")
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + kl_loss
return loss
vae_main.py
import torch
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os
import datetime
from vae import VAE
if not os.path.exists('./vae_img'):
os.mkdir('./vae_img')
def to_img(x):
x = x.clamp(0, 1) # torch.clamp(input,min,max) 把输入的张量加紧到指定区间内
x = x.view(x.size(0), 1, 28, 28) # batch,channel,w,h
return x
num_epochs = 300
batch_size = 128
img_transform = transforms.Compose([
transforms.ToTensor()
# transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])
dataset = MNIST('./data', transform=img_transform, download=True)
datalodader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
start_time = datetime.datetime.now()
model = VAE()
if torch.cuda.is_available():
print('cuda is ok!')
model = model.to('cuda')
else:
print('cuda is no!')
# loss_function = VAE.loss_function
optimizer = optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(num_epochs):
model.train()
train_loss = 0
for batch_idx, data in enumerate(datalodader):
img, _ = data
img = img.view(img.size(0), -1) # 把图像拉平
img = Variable(
img) # tensor不能求导,variable能(其包含三个参数,data:存tensor数据,grad:保留data的梯度,grad_fn:指向function对象,用于反向传播的梯度计算)但我印象中好像tensor可以求梯度 见13讲
img = (img.cuda() if torch.cuda.is_available() else img)
optimizer.zero_grad()
recon_batch, mu, logvar = model(img)
# 计算损失函数
recon_loss = F.mse_loss(recon_batch, img, reduction="sum") # use "mean" may have a bad effect on gradients
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# kl_loss = torch.sum(kl_loss)
loss = recon_loss + kl_loss
# loss = VAE.loss_function(recon_batch, img, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
# if batch_idx % 100 == 0:
# end_time = datetime.datetime.now()
# print('Train Epoch: {} [{}/{}({:.0f}%] Loss:{:.6f} time:{:.2f}s'.format(
# epoch,
# batch_idx * len(img),
# len(datalodader.dataset),
# (batch_idx * len(img)) / (len(datalodader.dataset)),
# loss.item() / len(img),
# (end_time - start_time).seconds
# ))
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(datalodader.dataset)
))
if epoch % 2 == 0:
# 生成图像
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
z = torch.randn(batch_size, 20).to(device)
out = model.decoder(z).view(-1, 1, 28, 28)
# print(out)
# print(type(out))
save_image(out, './vae_img/sample-{}.png'.format(epoch))
# 重构图像
# recon_batch = model.encoder(img).to(device)
save = to_img(recon_batch.cpu().data)
# print(save)
# print(type(save))
save_image(save, './vae_img/image_{}.png'.format(epoch))
torch.save(model.state_dict(), './vae_weight/vae.pth')
实现结果如下:
image.png(原生) 对比 sample.png(生成)