论文传送门:https://arxiv.org/pdf/1312.6114.pdf
参考代码:GitHub - AntixK/PyTorch-VAE: A Collection of Variational Autoencoders (VAE) in PyTorch.
VAE的目的:构建一个解码器Decoder,通过输入从标准正态分布中采样得到的采样变量X,得到生成样本Y,使Y的分布与输入样本X的分布尽可能接近,从而完成图像生成任务。
VAE的模型结构:编码器Encoder+解码器Decoder,输入样本X经过编码器Encoder输出分布的均值和方差(对数),从该分布中采样得到采样变量X,采样变量X经过解码器Decoder输出生成样本Y。
VAE的方法:通过构建损失函数:
①使得生成样本Y接近输入样本X;
②使得编码器Encoder的输出分布接近标准正态分布,使得分布方差不为0,即采样变量X具有随机性,保证模型的生成能力。
VAE的损失函数:Loss = recons_loss + w * kld_loss
recons_loss描述生成样本Y与输入样本X之间的距离,使用MSE计算;
kld_loss描述解码器Encoder输出分布与标准正态分布之间的距离,使用KL散度计算,化简过程如下图;
w为kld_loss项系数。实现过程见代码82-85行。
重参数技巧:直接从编码器Encoder输出的分布中采样难以实现,但我们知道其均值mu和标准差std,于是我们从标准正态分布中采样得到Z',Z = mu + Z' * std计算得到Z,等价于从输出分布中进行采样得到Z。实现过程见代码63-66行。
import os
import torch