pytorch实现全连接网络的vae

网络全用全连接层nn.Linear()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import os
import math


class VAE(nn.Module):
    """Fully-connected variational autoencoder for 28x28 images (e.g. MNIST).

    Encoder: 784 -> 256 -> (mu, logvar), latent dimension 8.
    Decoder: 8 -> 256 -> 784 with a sigmoid output in [0, 1].
    """

    def __init__(self):
        super(VAE, self).__init__()
        # Encoder.
        self.E_fc1 = nn.Linear(28 * 28, 256)
        self.fc_mu = nn.Linear(256, 8)
        # NOTE: this head predicts the LOG-variance, not the variance itself
        # (see the exp(0.5 * var) in repara below).
        self.fc_var = nn.Linear(256, 8)
        # Decoder.
        self.z_fc = nn.Linear(8, 256)
        self.D_fc1 = nn.Linear(256, 28 * 28)

    def repara(self, mu, var):
        """Reparameterization trick: z = mu + eps * std, eps ~ N(0, I).

        `var` is treated as log-variance, so std = exp(0.5 * var).
        """
        std = torch.exp(0.5 * var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, z):
        """Decode a latent code z into a flattened 784-dim image in [0, 1]."""
        hidden = F.relu(self.z_fc(z))
        return torch.sigmoid(self.D_fc1(hidden))

    def forward(self, x):
        """Encode x, sample a latent code, and decode it.

        Accepts any tensor that flattens to (-1, 784); returns
        (reconstruction, mu, logvar).
        """
        hidden = F.relu(self.E_fc1(x.view(-1, 28 * 28)))
        mu = self.fc_mu(hidden)
        var = self.fc_var(hidden)
        z = self.repara(mu, var)
        return self.decoder(z), mu, var


if __name__ == '__main__':

    # Hyperparameters.
    num_epochs = 100
    batch_size = 64

    # Fall back to CPU when CUDA is unavailable; identical behavior on GPU hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = VAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    train_dataset = torchvision.datasets.MNIST(root='./data',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)
    test_dataset = torchvision.datasets.MNIST(root='./data',
                                              train=False,
                                              transform=transforms.ToTensor(),
                                              download=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Create the output directory once, up front.
    os.makedirs('./vae_val_result', exist_ok=True)

    step = 0  # global iteration counter (renamed: 'iter' shadowed the builtin)
    model.train()
    for epoch in range(num_epochs):
        for img, _labels in train_loader:
            img = img.to(device)

            out, mu, var = model(img)

            # Loss = reconstruction (summed BCE) + KL divergence to N(0, I);
            # 'var' is the log-variance predicted by the encoder.
            rec_loss = F.binary_cross_entropy(out, img.view(-1, 28*28), reduction='sum')
            kl_loss = -0.5 * torch.sum(1 + var - mu.pow(2) - var.exp())
            loss = rec_loss + kl_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                # .item() detaches scalars for formatting.
                print("epoch {} iter {}: kl_loss={:.4f}, rec_loss={:.4f}, loss={:.4f}".format(
                    epoch, step, kl_loss.item(), rec_loss.item(), loss.item()))

            if step % 500 == 0:
                # Periodic qualitative check: save reconstructions next to inputs.
                model.eval()
                with torch.no_grad():
                    for i, (test_img, _test_labels) in enumerate(test_loader):
                        test_img = test_img.view(test_img.size(0), -1).to(device)
                        test_out, _, _ = model(test_img)
                        if i == 0:
                            test_out = test_out.view(-1, 1, 28, 28)
                            test_img = test_img.view(-1, 1, 28, 28)
                            save_data = torch.cat([test_out[9:17], test_img[9:17]])
                            torchvision.utils.save_image(
                                save_data,
                                './vae_val_result/epoch' + str(epoch) + 'iter' + str(step) + '.jpg')
                            # Only the first batch is used; skip the rest.
                            break
                # BUG FIX: restore training mode after validation (the original
                # left the model in eval mode for the rest of training).
                model.train()
            step += 1

    print("Training is over!")

    # Sample from the prior N(0, I) and decode to images.
    model.eval()
    with torch.no_grad():
        z = torch.randn(64, 8, device=device)
        sample_data = model.decoder(z)
        # BUG FIX: removed the stray space in the output filename.
        torchvision.utils.save_image(sample_data.view(-1, 1, 28, 28),
                                     './vae_val_result/sample_img.jpg')
        print("Saving successfully sample_image!")

训练100个epoch的测试结果(上面一行为重构结果,下面一行为ground truth):
[图片:重构结果与 ground truth 对比图]
采样结果:
[图片:从先验分布采样的生成结果]
从结果可以看出全连接网络的VAE重构结果的效果已经比较好了,采样结果一般。若改用卷积网络,生成结果应该会有很大提升。

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值