pytorch训练GAN的代码(基于MNIST数据集)

论文:Generative Adversarial Networks
作者:Ian J. Goodfellow
年份:2014年

从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简单记录一下,有时间会补充。
更多关于GAN的可以看我另一篇:https://blog.csdn.net/demo_jie/article/details/106724016

直接讲代码实现部分,这个代码是用pytorch训练GAN,基于MNIST数据集
真实图片:
在这里插入图片描述

代码:

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os

if not os.path.exists('img'):
    os.mkdir('img')


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)   #输出限制在0,1范围内
    out = out.view(-1, 1, 28, 28)
    return out


# 初始化参数
batch_size = 128
num_epoch = 10
z_dimension = 100
# 对图片进行一些前期处理操作
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# img_transform = transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# ]
# 数据集下载
mnist = datasets.MNIST(
    root='E:/low-light/deep learning/GAN/data/', train=True, transform=img_transform, download=True)
# 数据集加载
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)


# 判别网络
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(nn.Linear(784, 256),
                                 nn.LeakyReLU(0.2),
                                 nn.Linear(256, 256),
                                 nn.LeakyReLU(0.2), nn.Linear(256, 1),
                                 nn.Sigmoid())  # sigmoid激活函数得到一个01之间的概率进行二分类

    def forward(self, x):
        x = self.dis(x)
        return x


# 生成器
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh())  # Tanh激活函数是希望生成的假的图片数据分布能够在-11之间。

    def forward(self, x):
        x = self.gen(x)
        return x


D = discriminator()
G = generator()
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
# 判别器的训练由两部分组成,第一部分是真的图像判别为真,第二部分是假的图片判别为假,在这两个过程中,生成器的参数不参与更新。
# 二进制交叉熵损失和优化器
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
# 开始训练
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # ========================================================================训练判别器
        img = img.view(num_img, -1)  # # 将图片展开乘28x28=784
        # real_img = Variable(img).cuda()
        # real_label = Variable(torch.ones(num_img)).cuda()
        # fake_label = Variable(torch.zeros(num_img)).cuda()
        real_img = Variable(img)
        real_label = Variable(torch.ones(num_img))  # 定义真实label为1
        fake_label = Variable(torch.zeros(num_img))  # 定义假label为1

        # 计算 real_img 的损失
        real_out = D(real_img)  # 将真实的图片放入判别器中
        d_loss_real = criterion(real_out, real_label)  # 得到真实图片的loss
        real_scores = real_out  # 越接近一越好

        # 计算 fake_img的损失
        # z = Variable(torch.randn(num_img, z_dimension)).cuda()
        z = Variable(torch.randn(num_img, z_dimension))  # 随机生成一些噪声
        fake_img = G(z)  # 放入生成网络生成一张假的图片
        fake_out = D(fake_img)  ## 判别器判断假的图片
        d_loss_fake = criterion(fake_out, fake_label)  ## 得到假的图片的loss
        fake_scores = fake_out  # 越接近0越好

        # 反向传播和优化
        d_loss = d_loss_real + d_loss_fake  # 将真假图片的loss加起来
        d_optimizer.zero_grad()  # 每次梯度归零
        d_loss.backward()  # 反向传播
        d_optimizer.step()  # 更新参数

        # =====================================================================训练生成器

        # 计算fake_img损失
        # z = Variable(torch.randn(num_img, z_dimension)).cuda()
        z = Variable(torch.randn(num_img, z_dimension))  # 得到随机噪声
        fake_img = G(z)  # 生成假的图片
        output = D(fake_img)  # 经过判别器得到结果
        g_loss = criterion(output, real_label)  ##得到假的图片与真实图片label的loss

        # 反向传播和优化
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f},D real: {:.6f}, D fake: {:.6f}'.format(
                epoch, num_epoch, d_loss.item(), g_loss.item(),
                real_scores.data.mean(), fake_scores.data.mean()))

    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, 'real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, 'fake_images-{}.png'.format(epoch + 1))
torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')

运行结果:

在这里插入图片描述

这次一共跑了10次,以下是生成的噪声图片,分别是跑了1,3,5,7,9,10次的图片(训练次数太少了,所以效果不明显,可以自己设置训练次数)请添加图片描述

请添加图片描述

请添加图片描述

请添加图片描述

请添加图片描述
请添加图片描述
生成的真实图片:
在这里插入图片描述

  • 3
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值