变分自动编码器Variational Auto-Encoding(VAE)基本原理和理解,附上python代码(包含中文注释)

  

VAE原理

  我们知道,对于生成模型而言,主流的理论模型可以分为隐马尔可夫模型HMM、朴素贝叶斯模型NB和高斯混合模型GMM,而VAE的理论基础就是高斯混合模型。

       什么是高斯混合模型呢?就是说,任何一个数据的分布,都可以看作是若干高斯分布的叠加。

如图所示,上面黑色线即为高斯混合分布的例子,如果把该条线拆分可获得若干条浅蓝色曲线(高斯分布)的叠加。有趣的是,当拆分数量达到512时,其叠加的分布相对于原始分布而言误差就非常小了。

       然后,我们可以利用这一理论模型去考虑如何给数据进行编码。一种最直接的思路是,直接用每一组高斯分布的参数作为一个编码值实现编码。

这里m表示每一条高斯分布的曲线,每采样一个m,则会获得对应的高斯分布N\left ( \mu^{m}, \Sigma^{m} \right ),对于这条曲线的分布函数P(x)而言,它可以表示为以下公式:

p\left ( x \right ) = \sum_{m} p\left ( m \right ) p\left ( x|m \right )

上述的这种编码方式是非常简单粗暴的,它对应的是我们之前提到的离散的、有大量失真区域的编码方式。于是我们需要对目前的编码方式进行改进,使得它成为连续有效的编码。 

 现在我们的编码规则换成连续的变量z,同时,我们假设z服从正态分布N\left ( 0,1 \right )(这里可以假设任意分布,需根据数据来定

对于每一个分布函数z,都会有对应方差\mu和均值\sigma,两者决定了高斯分布的形状和范围。然后,累加所有可能的z就获得了连续状态下的P(x)。

P(x) = \int_{z} P(z)P(x|z) d z

 这里P(z)是已知的关于z的高斯分布,而P(x|z)是未知的每个z的高斯分布⭐⭐⭐

实际上我们求解P(x|z)就是求解关于z分布所对应的所有的 \mu\sigma。这通常是一个很难求解的过程,因此可以使用神经网络来建模。

———————————————————⭐⭐⭐————————————————————

  构建一个编码器Encoder如下,它可以求出一个数据分布q(x|z),来表示关于x的数据分布,用于推进P(x|z)的求解:

然后,构建一个解码器Decoder如下 ,它可以求解 \mu\sigma两个参数,即等价于求解P(x|z):


关于Encoder这部分内容就比较偏数学了,实际上就是推导公式引入一个新的变量 q(x|zKL\left ( p_{1} (x)||P_{2} (x) \right )=\int _{x} p_{1} (x)log\frac{ p_{1} (x)}{ p_{2} (x)}dx),通过在给定q(x|z)然后优化P(x|z)使其输出值足够高。这样做可以使算法获得更好的性质,更易于求解,同时也导出了VAE的损失函数。总的流程图如下:

———————————————————⭐⭐⭐————————————————————

VAE损失函数

下面,对于要实现VAE模型来说,损失函数是关键,这里简单推导一下损失函数。

损失函数包含重构损失+KL散度两部分

1. KL散度

前面我们引入了q\left ( z|x \right )P\left ( x|z \right ),前者表示VAE模拟出的数据分布情况,后者表示数据真实的分布,接下来,根据两者之间的差异性,可以导出一个KL散度公式。

KL散度标准公式定义如下

KL(p_{1}(x)||p_{2}(x)) = \int _{x}p_{1}(x)log\frac{p_{1}(x)}{p_{2}(x)}dx

针对两种分布情况,p_{1}=N_{1}({\mu_{1}, \delta_{1}}), p_{2}=N_{2}({\mu_{2}, \delta_{2}}),可以简化上述公式如下(复制来的,可以忽略)

 

 以上是KL散度的标准公式,可以不看。

前面我们已经讨论了VAE是关于高斯混和模型的讨论。因此,我们定义p_{2}=N_{2}(\mu_{2}, \delta_{2})是从标准正态分布中采样的,即p_{2}=N_{2}(0, 1),那么,

KL(p_{1}||p_{2})) = -\frac{1}{2}\times [2log\delta_{1}+1-\delta_{1}^{2}-\mu_{1}^{2}]

不难看出,p1就是原始数据的真实分布,p2就是VAE模拟出的数据分布。

2. 重构损失

重构损失就是重构结果和原始数据之间的误差,可以用BCE、MSE、L1等等多种损失函数。

总结一下,总的损失函数定义如下:

Loss=BCE(x,\widehat{x})+KL(p_{1}||p_{2}) 

———————————————————⭐⭐⭐————————————————————

VAE实现

附上python代码(包含中文注释)

import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os

# 图像转为二维可视化
def to_img(x):
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x

# 定义VAE类
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        INOUT_num = 784                                          # 输入(输出)大小
        hidden_num1 = 400                                        # 隐藏层大小
        hidden_num2 = 20                                         # 隐变量大小
        self.fc1 = nn.Linear(INOUT_num, hidden_num1)             # (编码) 全连接层
        self.fc21 = nn.Linear(hidden_num1, hidden_num2)          # (编码) 计算 mean
        self.fc22 = nn.Linear(hidden_num1, hidden_num2)          # (编码) 计算 logvar
        self.fc3 = nn.Linear(hidden_num2, hidden_num1)           # (解码) 隐藏层
        self.fc4 = nn.Linear(hidden_num1, INOUT_num)             # (解码) 输出层

    def encode(self, x):
        # 全连接层
        hidden1 = self.fc1(x)
        # relu层
        h1 = F.relu(hidden1)
        # 计算mean
        mu = self.fc21(h1)
        # 计算var
        logvar = self.fc22(h1)
        return mu, logvar

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()                            # mul是乘法的意思,然后exp_是求e的次方并修改原数值  所有带"—"都是inplace的 意思就是操作后 原数也会改动

        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()  # 在cuda中生成一个std.size()的张量,标准正态分布采样,类型为FloatTensor
        else:
            eps = torch.FloatTensor(std.size()).normal_()       # 生成一个std.size()的张量,正态分布,类型为FloatTensor
        eps = Variable(eps)                                     # Variable是torch.autograd中很重要的类。它用来包装Tensor,将Tensor转换为Variable之后,可以装载梯度信息。
        repar = eps.mul(std).add_(mu)
        return repar

    def decode(self, z):
        # 隐藏层
        hidden2 = self.fc3(z)
        # relu层
        h3 = F.relu(hidden2)
        # 隐藏层
        hidden3 = self.fc4(h3)
        # sigmoid层
        output = F.sigmoid(hidden3)
        return output

    def forward(self, x):
        mu, logvar = self.encode(x)           # 编码
        z = self.reparametrize(mu, logvar)    # 重新参数化成正态分布
        decodez = self.decode(z)              # 解码
        return decodez, mu, logvar

def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    BCE = reconstruction_function(recon_x, x)  # MSE loss
    KLD = -0.5 * torch.sum(logvar + 1 - mu.pow() - logvar.exp())
    # KL divergence
    return BCE + KLD

if __name__ == '__main__':
    #  创建路径
    if not os.path.exists('./vae_img'):
        os.mkdir('./vae_img')
    #  VAE参数设置
    num_epochs = 30
    batch_size = 128
    learning_rate = 1e-3
    # 定义数据格式
    img_transform = transforms.Compose([
        transforms.ToTensor()   # 将原始的PILImage格式或者numpy.array格式的数据格式化为可被pytorch快速处理的张量类型。
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])  # orchvision.transforms是pytorch中的图像预处理包。一般用Compose把多个步骤整合到一起:
    # 加载数据
    dataset = MNIST('./data', transform=img_transform, download=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    model = VAE() # 实例化VAE
    if torch.cuda.is_available():
        model.cuda()
    
    reconstruction_function = nn.MSELoss(size_average=False)   # 定义损失函数,可修改其他

    optimizer = optim.Adam(model.parameters(), lr = learning_rate)
    
#开始训练
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for batch_idx, data in enumerate(dataloader):
            img, _ = data
            img = img.view(img.size(0), -1)
            img = Variable(img)
            if torch.cuda.is_available():
                img = img.cuda()
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(img)
            loss = loss_function(recon_batch, img, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch,
                    batch_idx * len(img),
                    len(dataloader.dataset), 100. * batch_idx / len(dataloader),
                    loss.item() / len(img)))

        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(dataloader.dataset)))
        if epoch % 10 == 0:
            save = to_img(recon_batch.cpu().data)
            save_image(save, './vae_img/image_{}.png'.format(epoch))

    torch.save(model.state_dict(), './vae.pth')

Encoder这部分的内容是VAE最为巧妙的地方,是灵魂所在,感兴趣的小伙伴可以深入探究下,附上相关链接:【学习笔记】生成模型——变分自编码器。以上内容如有错误,欢迎指出积极讨论!

  • 12
    点赞
  • 71
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值