对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)

5 篇文章 0 订阅
3 篇文章 0 订阅

在上一篇博文也就是对抗变分自编码器——AVB(Adversarial Variational Bayes: Unifying Variational Autoencoders and … ) (一)中,有关github的代码、注释和计算流程图已经贴出,但上述代码适用于图像识别领域的“Hello World!”——mnist数据集,后来我根据自己实验的需要对代码进行了一些改动:

加载一些必要的库

import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
# from tensorflow.examples.tutorials.mnist import input_data
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import scipy.io as sio

根据自己制作的图像列表的txt文档加载自己的数据集

# 根据自己制作的图像列表的txt文档加载自己的数据集
def default_loader(path):
    return Image.open(path).convert('RGB')


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        print(txt)
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split('\t')
            imgs.append((words[0], words[1]))
        print(imgs)
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)

        if self.transform is not None:
            img = self.transform(img)
        else:
            img = Tensor.from_numpy(img)
        return img, label

    def __len__(self):
        return len(self.imgs)






##########################################
# 转载或使用请附上本文链接:https://blog.csdn.net/S20144144/article/details/
# 纵心似水

定义训练迭代器和测试迭代器

transform = transforms.Compose([transforms.Scale((150, 150)), transforms.ToTensor()])   # 转换为张量
train_txt_path = '.txt'  # 自己的数据集的位置列表txt
trainset = MyDataset(txt=train_txt_path, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=mb_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=1, shuffle=False)








##########################################
# 转载或使用请附上本文链接:https://blog.csdn.net/S20144144/article/details/
# 纵心似水

原github代码部分

再次附上github代码链接:https://github.com/wiseodd/generative-models

def log(x):
    return torch.log(x + 1e-8)

# Encoder: q(z|x,eps)   # 编码器
Q = torch.nn.Sequential(
    torch.nn.Linear(X_dim + eps_dim, h_dim),   # 一个全连接层
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, z_dim)              # 一个全连接层
)

# Decoder: p(x|z)       # 解码器
P = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),             # 一个全连接层
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),             # 一个全连接层
    torch.nn.Sigmoid()
)

# Discriminator: T(X, z)   # 判别器
T = torch.nn.Sequential(
    torch.nn.Linear(X_dim + z_dim, h_dim),     # 一个全连接层
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1)                  # 一个全连接层   # 输出为一维,即一个数
)

Q.cuda()
P.cuda()
T.cuda()

def reset_grad():          # 重置梯度为0
    Q.zero_grad()
    P.zero_grad()
    T.zero_grad()

Q_solver = optim.Adam(Q.parameters(), lr=lr)    # 三个模块的优化求解器
P_solver = optim.Adam(P.parameters(), lr=lr)
T_solver = optim.Adam(T.parameters(), lr=lr)

开始训练迭代

for it in range(1000000):    # 开始迭代
    print(it)
    # X = sample_X(mb_size)   # 输入为从训练集中采样并进行类型转换后的数据
    for i, (X, _) in enumerate(train_loader):
        X = X.view(-1, 150 * 150 * 3)
        X = Variable(X)
        eps = Variable(torch.randn(mb_size, eps_dim))
        z = Variable(torch.randn(mb_size, z_dim))   # 由标准正态分布(均值为0,方差为1)中随机采样

        # Optimize VAE   # 优化变分自编码器
        # z_sample = Q(torch.cat([X, eps], 1))   # 按列拼接,需要维度一致方能行对齐
        z_sample = Q(torch.cat([X, eps], 1).cuda())   # 按列拼接,需要维度一致方能行对齐
        # X_sample = P(z_sample)
        X_sample = P(z_sample.cuda())
        # T_sample = T(torch.cat([X, z_sample], 1))
        T_sample = T(torch.cat([X, z_sample.cpu()], 1).cuda())

        disc = torch.mean(-T_sample)    # 判别器输出的负数的均值
        loglike = -nn.binary_cross_entropy(X_sample, X.cuda(), size_average=False) / mb_size
        # 交叉熵, 最小化交叉熵损失函数等价于最大化对数似然, 让重构图像尽可能接近原始输入图像

        elbo = -(disc + loglike)   # 证据下界, 常用在变分推断中

        elbo.backward()    # 证据下界反向传播,优化编码器与解码器
        Q_solver.step()
        P_solver.step()
        reset_grad()     # 重置梯度为0

        # Discriminator T(X, z)     # 对于判别器,优化判别器
        # z_sample = Q(torch.cat([X, eps], 1))  # z_sample是输入经过编码器后的输出
        z_sample = Q(torch.cat([X, eps], 1).cuda())  # z_sample是输入经过编码器后的输出
        T_q = nn.sigmoid(T(torch.cat([X, z_sample.cpu()], 1).cuda()))
        T_prior = nn.sigmoid(T(torch.cat([X, z], 1).cuda()))

        T_loss = -torch.mean(log(T_q) + log(1. - T_prior))

        T_loss.backward()
        T_solver.step()
        reset_grad()     # 重置梯度为0

        if (it + 1) % 10 == 0:
            print('Iter-{}; ELBO: {:.4}; T_loss: {:.4}'
                  .format(it, -elbo.data[0], -T_loss.data[0]))

        # Print and plot every now and then
    if (it + 1) % 10 == 0:

        for k, (X, _) in enumerate(test_loader):

            if k < 3:

                X = X.view(-1, 150 * 150 * 3)
                X = Variable(X)
                eps = Variable(torch.randn(1, eps_dim))
                z = Variable(torch.randn(1, z_dim))  # 由标准正态分布(均值为0,方差为1)中随机采样

                z_sample = Q(torch.cat([X, eps], 1).cuda())  # 按列拼接,需要维度一致方能行对齐
                X_random = P(z.cuda()).data.cpu().numpy()
                reconst = P(z_sample).data.cpu().numpy()  # 原始输入输入解码器后的输出, 取前16个作为示范
                X = X.numpy()

                reconst1 = reconst.reshape(3, 150 * 150)
                reconst2 = reconst1[0, :]
                reconst3 = np.zeros(((3, 150, 150)))
                reconst3 = np.array(reconst3)
                reconst3[0, :, :] = reconst1[0, :].reshape(150, 150)
                reconst3[1, :, :] = reconst1[1, :].reshape(150, 150)
                reconst3[2, :, :] = reconst1[2, :].reshape(150, 150)
                reconst = reconst3
                # save_image(torch.from_numpy(sample), 'try/reconst_iter_' + str(it) + '_' + str(k) + '.png')

                x1 = X.reshape(3, 150 * 150)
                x2 = x1[0, :]
                x3 = np.zeros(((3, 150, 150)))
                x3 = np.array(x3)
                x3[0, :, :] = x1[0, :].reshape(150, 150)
                x3[1, :, :] = x1[1, :].reshape(150, 150)
                x3[2, :, :] = x1[2, :].reshape(150, 150)
                xx = x3
                # save_image(torch.from_numpy(xx), 'try/input_' + str(it) + '_' + str(k) + '.png')

                image_show = np.concatenate((xx, reconst), axis=2)
                image_show = image_show[np.newaxis, :, :, :]
                save_image(torch.from_numpy(image_show), 'try/compare_' + str(it) + '_' + str(k) + '.png')

                random1 = X_random.reshape(3, 150 * 150)
                random2 = random1[0, :]
                random3 = np.zeros(((3, 150, 150)))
                random3 = np.array(random3)
                random3[0, :, :] = random1[0, :].reshape(150, 150)
                random3[1, :, :] = random1[1, :].reshape(150, 150)
                random3[2, :, :] = random1[2, :].reshape(150, 150)
                random_image = random3

                save_image(torch.from_numpy(random_image), 'try/random_' + str(it) + '_' + str(k) + '.png')






##########################################
# 转载或使用请附上本文链接:https://blog.csdn.net/S20144144/article/details/
# 纵心似水

代码运行结果

代码运行结果如下所示:
compare_0_0.png:
在这里插入图片描述
compare_0_1.png:
在这里插入图片描述
compare_0_2.png:
在这里插入图片描述
compare_9_0.png:
在这里插入图片描述
compare_9_1.png:
在这里插入图片描述
compare_9_2.png:
在这里插入图片描述
compare_19_0.png:
在这里插入图片描述
compare_19_1.png:
在这里插入图片描述
compare_19_2.png:
在这里插入图片描述
。。。。。。
random_0_0.png:
在这里插入图片描述
random_0_1.png:
在这里插入图片描述
random_0_2.png:
在这里插入图片描述
random_9_0.png:
在这里插入图片描述
random_9_1.png:
在这里插入图片描述
random_9_2.png:
在这里插入图片描述
random_19_0.png:
在这里插入图片描述
random_19_1.png:
在这里插入图片描述
random_19_2.png:
在这里插入图片描述
。。。。。。
可以看出,随着迭代次数增加,图片生成质量也会越高。

一点小问题

代码中对隐含层的z_sample位置再次进行了高斯随机采样以生成新的人脸图像,对于对抗变分自编码器来说这是否合理?(AVB的隐含层没有显式的概率分布,其为一个黑箱模型)对于AVB有没有更好的生成新的人脸图像的采样方法呢?
这篇对抗变分自编码器——AVB(Adversarial Variational Bayes: Unifying Variational Autoencoders and … ) (二)就先写到这里,若有疏漏、不恰当或者错误的地方还请及时指出。另外,代码还需进一步的优化,若你有更好的修改方式或想法,请不吝赐教。
这里附上上一篇博文的链接:对抗变分自编码器——AVB(Adversarial Variational Bayes: Unifying Variational Autoencoders and … ) (一):https://blog.csdn.net/S20144144/article/details/99467235

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值