深入解析与实现:变分自编码器(VAE)完整代码详解

VAE理论上一篇已经详细讲完了,虽然VAE已经是过去的东西了,但是它对后面强大的生成模型是很有指导意义的。接下来,我们简单实现一下其代码吧。

1 VAE在minist数据集上的实现

完整的代码如下,没有什么特别好讲的。

import cv2
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary

""" 就用线性层构造最简单的vae吧"""


class VAE(nn.Module):
    def __init__(self, image_size=28*28, hidden1=400, hidden2=100, latent_dims=40):
        super().__init__()

        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(image_size, hidden1),
            nn.ReLU(),
            nn.Linear(hidden1, hidden2),
            nn.ReLU(),
        )
        self.mu = nn.Sequential(
            nn.Linear(hidden2, latent_dims),
        )

        self.logvar = nn.Sequential(
            nn.Linear(hidden2, latent_dims),
        )   # 由于方差是非负的,因此预测方差对数

        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dims, hidden2),
            nn.ReLU(),
            nn.Linear(hidden2, hidden1),
            nn.ReLU(),
            nn.Linear(hidden1, image_size),
            nn.Tanh()
        )

    # 重参数,为了可以反向传播
    def reparametrization(self, mu, logvar):
        # sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))
        std = 0.5 * torch.exp(logvar)
        # N(mu, std^2) = N(0, 1) * std + mu
        z = torch.randn(std.size(), device=mu.device) * std + mu
        return z

    def forward(self, x):
        en = self.encoder(x)
        mu = self.mu(en)
        logvar = self.logvar(en)
        z = self.reparametrization(mu, logvar)

        return self.decoder(z), mu, logvar


def loss_function(fake_imgs, real_imgs, mu, logvar):

    kl = -0.5 * torch.sum(1 + logvar - torch.exp(logvar) - mu ** 2)
    reconstruction = ((real_imgs - fake_imgs)**2).sum()

    return kl, reconstruction


def train(num_epoch):

    write_fake = SummaryWriter(f'logs/fake')

    device = torch.device("cuda:0")

    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)

    vae = VAE().to(device)
    optimizer = torch.optim.Adam(vae.parameters(), lr=0.0003)

    vae.train()
    step = 0
    for epoch in range(num_epoch):
        for batch_indx, (inputs, _) in enumerate(trainloader):

            inputs = inputs.to(device)

            real_imgs = torch.flatten(inputs, start_dim=1)

            fake_imgs, mu, logvar = vae(real_imgs)

            loss_kl, loss_re = loss_function(fake_imgs, real_imgs, mu, logvar)

            loss_all = loss_kl + loss_re

            optimizer.zero_grad()
            loss_all.backward()
            optimizer.step()

            print(f"epoch:{epoch}, loss kl:{loss_kl.item()}, loss re:{loss_re.item()}, loss all:{loss_all.item()}")
            if batch_indx == 0:
                with torch.no_grad():
                    x = torch.randn((32, 40)).to(device)
                    fake = vae.decoder(x).reshape(-1, 1, 28, 28)
                    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)

                    write_fake.add_image(
                        "Mnist Fake Image", img_grid_fake, global_step=step
                    )
                    step += 1


if __name__ == "__main__":
    summary(VAE(), input_size=(1, 784))
    train(1000)

模型结构打印如下:
VAE [1, 784] –
├─Sequential: 1-1 [1, 100] –
│ └─Linear: 2-1 [1, 400] 314,000
│ └─ReLU: 2-2 [1, 400] –
│ └─Linear: 2-3 [1, 100] 40,100
│ └─ReLU: 2-4 [1, 100] –
├─Sequential: 1-2 [1, 40] –
│ └─Linear: 2-5 [1, 40] 4,040
├─Sequential: 1-3 [1, 40] –
│ └─Linear: 2-6 [1, 40] 4,040
├─Sequential: 1-4 [1, 784] –
│ └─Linear: 2-7 [1, 100] 4,100
│ └─ReLU: 2-8 [1, 100] –
│ └─Linear: 2-9 [1, 400] 40,400
│ └─ReLU: 2-10 [1, 400] –
│ └─Linear: 2-11 [1, 784] 314,384
│ └─Tanh: 2-12 [1, 784] –

训练结果,从结果上来看,是不如GAN的,主要原因在于其在KL散度和重建损失之间很难做到平衡,所以很难训练得好,当然原因是多方面的。
在这里插入图片描述

2 VAE的缺陷

变分自编码器(VAE, Variational Autoencoder)作为一种强大的深度学习模型,在生成建模领域有着广泛的应用,但它也存在一些缺陷,主要包括:

  • 生成样本质量:与生成对抗网络(GANs)相比,VAE生成的样本可能显得较为模糊或缺乏清晰度。尽管VAE能够生成连续且有结构的潜在空间,其生成的样本在某些情况下可能不够真实或细节不够丰富。

  • 潜在空间的连续性问题:虽然VAE设计用于学习连续的潜在空间,以允许插值和生成流畅的变化序列,但在实践中,这种连续性可能不如理论中那样完美。潜在空间中可能会出现空洞或不连贯区域,影响样本生成的质量和连续性变换的效果。

  • KL散度的平衡问题:VAE通过在其损失函数中加入KL散度项来约束潜在变量的分布,以确保它接近先验分布(通常是标准正态分布)。然而,KL散度的权重难以选择,如果设置不当,可能导致模型过分关注重构损失而忽视了潜在空间的平滑性和多样性,或者相反。

  • 训练难度与稳定性:VAE的训练过程比一些其他模型更为复杂,涉及到优化 Evidence Lower Bound (ELBO),这可能导致训练过程较为不稳定,需要更多的计算资源和更长的训练时间。特别是优化过程中对似然的近似以及对数似然的下界处理增加了训练的复杂度。

  • 表达能力与模型容量:由于VAE的编码器和解码器结构相对简单(通常为全连接层或简单的卷积层),在处理高度复杂的高维数据时,其表达能力可能受限,影响生成样本的质量和多样性。

这些缺陷提示研究者和实践者在使用VAE时需要仔细调整模型架构、损失函数的平衡以及训练策略,以最大化其生成能力和实用性。

  • 25
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

idealmu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值