VAE Code Implementation

Original reference: https://shenxiaohai.me/2018/10/20/pytorch-tutorial-advanced-02/

The Jupyter notebook accompanying this article has been uploaded to my CSDN resources.

1. Import the packages needed for training

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

2. Configure the device

# Device configuration
# torch.cuda.set_device(0)  # uncomment to pin PyTorch to a specific GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the output folder for generated samples if it does not exist
sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

3. Set the hyperparameters and build the DataLoader

# Hyperparameters
image_size = 784      # 28*28 pixels, flattened into a single vector
h_dim = 400           # hidden-layer size
z_dim = 20            # latent-space dimensionality
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# MNIST dataset (ToTensor scales pixel values into [0, 1])
dataset = torchvision.datasets.MNIST(root='../../../data/minist',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../../data/minist\MNIST\raw\train-images-idx3-ubyte.gz
Extracting ../../../data/minist\MNIST\raw\train-images-idx3-ubyte.gz to ../../../data/minist\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../../data/minist\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ../../../data/minist\MNIST\raw\train-labels-idx1-ubyte.gz to ../../../data/minist\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../../data/minist\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ../../../data/minist\MNIST\raw\t10k-images-idx3-ubyte.gz to ../../../data/minist\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../../data/minist\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../../../data/minist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../../../data/minist\MNIST\raw
Processing...
Done!

Inspect the contents of the DataLoader

# Imports for visualization
import matplotlib.pyplot as plt    # plt displays images
import matplotlib.image as mpimg   # mpimg reads image files
import numpy as np

x = next(iter(data_loader))[0]     # fetch one batch of images (labels discarded)
x.shape
torch.Size([128, 1, 28, 28])

A single iteration of the DataLoader therefore yields a batch of images with shape 128×1×28×28:

128: the batch size

1: the number of channels (grayscale images have a single channel)

28×28: the pixel data of one channel

Note also that transforms.ToTensor() has already scaled pixel values into [0, 1], which is what allows the binary cross-entropy reconstruction loss used later to treat each pixel as a Bernoulli probability.

plt.imshow(x[0][0])
<matplotlib.image.AxesImage at 0x2ed7fdb86c8>


[Figure: a single MNIST digit rendered by plt.imshow]

The figure above shows the content of a single image from the batch.
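To preview the whole batch at once instead of a single digit, here is a minimal sketch using torchvision.utils.make_grid (the grid width nrow=16 is an arbitrary choice, not from the original article):

# Arrange the 128 images into a 16-wide grid; make_grid repeats the
# single gray channel to three channels, so the result is CHW with C=3
grid = torchvision.utils.make_grid(x, nrow=16, padding=2)
plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC, the layout matplotlib expects
plt.axis('off')
plt.show()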

4. Define the model

# VAE model
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)  # mean vector
        self.fc3 = nn.Linear(h_dim, z_dim)  # log-variance vector
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    # Encoding: map the input to the mean and log-variance of q(z|x)
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    # Reparameterization: sample a latent vector z from mu and log_var
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    # Decoding: map the latent vector back to image space
    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))  # torch.sigmoid; F.sigmoid is deprecated

    # Full forward pass: encode -> reparameterize -> decode
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var
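A note on the reparameterization trick used above: sampling z directly from N(μ, σ²) is not differentiable with respect to the encoder outputs, so the model instead samples a noise vector ε from a standard normal and builds z deterministically. Since fc3 is interpreted as the log-variance, the standard deviation is recovered as exp(log σ² / 2), exactly as in the code:

$$ z = \mu + \sigma \odot \epsilon, \qquad \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^{2}\right), \qquad \epsilon \sim \mathcal{N}(0, I) $$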


The network structure is illustrated below:

[Figure: VAE network architecture diagram]


The components inside the red dashed box make up the loss function.
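Concretely, the loss is the negative evidence lower bound (ELBO): a binary cross-entropy reconstruction term plus the KL divergence between the approximate posterior N(μ, σ²) and the standard normal prior. With d = z_dim latent dimensions, the KL term has the closed form implemented in the training loop below:

$$ \mathcal{L} = \mathrm{BCE}(\hat{x}, x) \;-\; \frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right) $$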

5. Train the model

# Instantiate the model
model = VAE().to(device)

# Create the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Fetch a batch, flatten it, and run the forward pass
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Reconstruction loss plus KL divergence (the KL term measures how far
        # q(z|x) is from the standard normal prior; see the closed form above)
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backpropagation and optimization
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))

    # Test the model at the end of each epoch
    with torch.no_grad():
        # Images generated from random latent vectors
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Reconstructions of the last training batch, side by side with the originals
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))



Epoch[1/15], Step [100/469], Reconst Loss: 22325.9961, KL Div: 1292.9675
Epoch[1/15], Step [200/469], Reconst Loss: 16562.2441, KL Div: 2101.5405
Epoch[1/15], Step [300/469], Reconst Loss: 15128.4561, KL Div: 2418.6357
Epoch[1/15], Step [400/469], Reconst Loss: 14666.9990, KL Div: 2442.3835
Epoch[2/15], Step [100/469], Reconst Loss: 13904.7002, KL Div: 2920.0483
Epoch[2/15], Step [200/469], Reconst Loss: 12945.5293, KL Div: 2899.4502
Epoch[2/15], Step [300/469], Reconst Loss: 12416.3398, KL Div: 2859.1750
Epoch[2/15], Step [400/469], Reconst Loss: 11977.3125, KL Div: 2835.1426
Epoch[3/15], Step [100/469], Reconst Loss: 12504.1533, KL Div: 3067.2375
Epoch[3/15], Step [200/469], Reconst Loss: 11617.1113, KL Div: 3061.0508
Epoch[3/15], Step [300/469], Reconst Loss: 11711.5244, KL Div: 3130.3713
Epoch[3/15], Step [400/469], Reconst Loss: 11942.1924, KL Div: 3115.7471
Epoch[4/15], Step [100/469], Reconst Loss: 11302.0635, KL Div: 3117.6763
Epoch[4/15], Step [200/469], Reconst Loss: 11396.1738, KL Div: 3202.3250
Epoch[4/15], Step [300/469], Reconst Loss: 11127.0645, KL Div: 3171.7722
Epoch[4/15], Step [400/469], Reconst Loss: 10985.8320, KL Div: 3098.4009
Epoch[5/15], Step [100/469], Reconst Loss: 11460.6963, KL Div: 3230.8091
Epoch[5/15], Step [200/469], Reconst Loss: 10541.7783, KL Div: 3221.3369
Epoch[5/15], Step [300/469], Reconst Loss: 10609.5420, KL Div: 3134.0396
Epoch[5/15], Step [400/469], Reconst Loss: 10746.1963, KL Div: 3186.7300
Epoch[6/15], Step [100/469], Reconst Loss: 10613.0098, KL Div: 3161.1631
Epoch[6/15], Step [200/469], Reconst Loss: 10862.5127, KL Div: 3171.8523
Epoch[6/15], Step [300/469], Reconst Loss: 11125.9102, KL Div: 3209.8787
Epoch[6/15], Step [400/469], Reconst Loss: 10361.1904, KL Div: 3179.6394
Epoch[7/15], Step [100/469], Reconst Loss: 10869.8262, KL Div: 3277.3511
Epoch[7/15], Step [200/469], Reconst Loss: 10583.9775, KL Div: 3272.1274
Epoch[7/15], Step [300/469], Reconst Loss: 9966.8125, KL Div: 3117.8450
Epoch[7/15], Step [400/469], Reconst Loss: 10690.5742, KL Div: 3339.8892
Epoch[8/15], Step [100/469], Reconst Loss: 10644.7383, KL Div: 3299.1499
Epoch[8/15], Step [200/469], Reconst Loss: 10652.6270, KL Div: 3297.8372
Epoch[8/15], Step [300/469], Reconst Loss: 10541.0684, KL Div: 3166.6426
Epoch[8/15], Step [400/469], Reconst Loss: 10794.7314, KL Div: 3329.0159
Epoch[9/15], Step [100/469], Reconst Loss: 10347.5000, KL Div: 3291.0581
Epoch[9/15], Step [200/469], Reconst Loss: 10460.7686, KL Div: 3147.4270
Epoch[9/15], Step [300/469], Reconst Loss: 10217.2275, KL Div: 3206.6414
Epoch[9/15], Step [400/469], Reconst Loss: 10608.9072, KL Div: 3285.1226
Epoch[10/15], Step [100/469], Reconst Loss: 10454.6016, KL Div: 3290.0586
Epoch[10/15], Step [200/469], Reconst Loss: 10632.7822, KL Div: 3259.0110
Epoch[10/15], Step [300/469], Reconst Loss: 10514.3359, KL Div: 3185.3164
Epoch[10/15], Step [400/469], Reconst Loss: 10258.9453, KL Div: 3200.7063
Epoch[11/15], Step [100/469], Reconst Loss: 10047.3574, KL Div: 3214.2043
Epoch[11/15], Step [200/469], Reconst Loss: 9705.0078, KL Div: 3210.4810
Epoch[11/15], Step [300/469], Reconst Loss: 10236.5371, KL Div: 3314.7139
Epoch[11/15], Step [400/469], Reconst Loss: 10746.6348, KL Div: 3258.6812
Epoch[12/15], Step [100/469], Reconst Loss: 9837.2031, KL Div: 3136.6541
Epoch[12/15], Step [200/469], Reconst Loss: 10117.1963, KL Div: 3282.7031
Epoch[12/15], Step [300/469], Reconst Loss: 9952.3184, KL Div: 3148.8638
Epoch[12/15], Step [400/469], Reconst Loss: 10463.5410, KL Div: 3257.8792
Epoch[13/15], Step [100/469], Reconst Loss: 10687.4766, KL Div: 3315.0667
Epoch[13/15], Step [200/469], Reconst Loss: 10573.5977, KL Div: 3253.9087
Epoch[13/15], Step [300/469], Reconst Loss: 10285.8145, KL Div: 3226.7212
Epoch[13/15], Step [400/469], Reconst Loss: 9812.1465, KL Div: 3238.2170
Epoch[14/15], Step [100/469], Reconst Loss: 10094.8643, KL Div: 3275.3123
Epoch[14/15], Step [200/469], Reconst Loss: 10149.8086, KL Div: 3302.6235
Epoch[14/15], Step [300/469], Reconst Loss: 10553.0664, KL Div: 3305.8149
Epoch[14/15], Step [400/469], Reconst Loss: 10361.6904, KL Div: 3249.9197
Epoch[15/15], Step [100/469], Reconst Loss: 10149.0605, KL Div: 3283.0081
Epoch[15/15], Step [200/469], Reconst Loss: 10201.4980, KL Div: 3220.1846
Epoch[15/15], Step [300/469], Reconst Loss: 10114.3887, KL Div: 3159.8972
Epoch[15/15], Step [400/469], Reconst Loss: 10541.4033, KL Div: 3248.5728
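Note that because both loss terms are computed with reduction='sum', the printed values are totals over the whole batch of 128 images. Dividing by the batch size gives the per-image loss; a reconstruction loss of roughly 10500, for example, corresponds to about 10500 / 128 ≈ 82 nats per image.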

# Display the reconstruction results from the final epoch
reconsPath = './samples/reconst-15.png'
img = mpimg.imread(reconsPath)
plt.imshow(img)     # show the image
plt.axis('off')     # hide the axes
plt.show()

[Figure: reconst-15.png — original digits (left) and their reconstructions (right)]

# Display digits generated from random latent vectors
genPath = './samples/sampled-15.png'
img = mpimg.imread(genPath)
plt.imshow(img)     # show the image
plt.axis('off')     # hide the axes
plt.show()

[Figure: sampled-15.png — digits decoded from random latent vectors]
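Finally, if you want to reuse the trained model later without retraining, the standard PyTorch state_dict workflow applies. A minimal sketch (the filename vae.pth is an arbitrary choice, not from the original article):

# Persist only the learned parameters
torch.save(model.state_dict(), 'vae.pth')

# Later: rebuild the architecture and restore the weights
model = VAE().to(device)
model.load_state_dict(torch.load('vae.pth', map_location=device))
model.eval()  # switch to inference mode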
