pytorch生成式对抗网络GAN【一】:GAN生成MNIST手写体

pytorch生成式对抗网络GAN【一】:GAN生成MNIST手写体

生成式对抗网络是使用两个网络进行对抗式训练,一个网络叫生成器(Generator),另一个叫鉴别器(Discriminator)。GAN在2014年被提出,引起了很多研究者的兴趣,它主要是在结构上不同于其他神经网络。

1、GAN的概念

先说GAN中的鉴别器(Discriminator),把鉴别器记作D。鉴别器就如同常见的分类器一样,如果网络输入一幅猫的图像,一个分类器的目标常常是分辨这幅图像是不是猫;而鉴别器D则是要分辨这幅图像是“真猫”还是由生成器所生成的“假猫”。

假设生成器在未经训练时,只能生成一些杂乱无章的图像,这时鉴别器可以轻易地分辨“真猫”和“假猫”,而当生成器经过了足够多的训练之后,已经达到了相当高明的水平,几乎做到以假乱真,这时候鉴别器将不容易分辨出真猫和假猫, 此时鉴别器的正确率大概在50%此时我们的目的已经达到了,因为生成器已经达到了以假乱真的效果,且鉴别器也可以提取出比较完善的特征了。

再说生成器(Generator),把生成器记作G。生成器G的任务与一般的网络不大一样。我们考虑一个网络的结构是:输入→网络→输出。那么生成器G的输入往往是一组随机数,而它的输出却是一副图像。当然在未经训练时,生成器是没有任何以假乱真的能力的。为了提高自己的本领,它将输出的图像送入鉴别器中进行分类,理想的输出值是1,代表它骗过了鉴别器D,如果输出是0则代表它没能骗过鉴别器,以此作为生成器的损失函数。

2、GAN的训练过程

由于GAN中存在两个网络,因此对这两个网络训练的过程需要设计一下。

花开两朵,各表一枝,先说生成器G。生成器的训练相对简单,因为生成器的训练与数据集无关。生成器G的训练步骤是:

1.生成一组随机数z
2.将z送入生成器G中生成一幅图片img
3.将图片img送入鉴别器D中得到D的输出
4.输出与1做差得到误差,反向传播

鉴别器D的训练稍微复杂,在一次训练中,鉴别器需要接收两幅图像并进行两次前向传播。第一次前向传播在训练生成器G时已经进行过,生成器G使用D的输出与1做差,代表生成器期望输出为1;而鉴别器则使用该输出与0做差,代表鉴别器期望输出为0,将该误差记作fake-loss
第二次接收来自训练数据集的图像,label=1,即 使用“真”图像训练,将输出与label=1做差,该误差记作real-loss。

因此鉴别器的总损失是上述两个损失的平均值。

3、GAN代码:

拆解了一份git上的代码:
导入包:

import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch

模型参数:

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=4000, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)

生成一个存放训练过程图像的文件夹

os.makedirs("images", exist_ok=True)

生成的图像形状以及GPU训练参数

img_shape = (opt.channels, opt.img_size, opt.img_size)
cuda = True if torch.cuda.is_available() else False

生成器G定义

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

鉴别器D定义

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

损失函数选择BCELoss

# Loss function
adversarial_loss = torch.nn.BCELoss()

初始化生成器对象与鉴别器对象

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

设置GPU训练

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

用pytorch自带的数据包

# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

设置优化器

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

训练过程

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
        # Configure input
        real_imgs = Variable(imgs.type(Tensor))
        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()
        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
        # Generate a batch of images
        gen_imgs = generator(z)
        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()
        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
        % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
    )

保存模型

PATH1= 'generator.pkl'
PATH2= 'discriminator.pkl'
torch.save(generator,PATH1)
torch.save(discriminator,PATH2)

4、训练过程展示

未训练时生成的图像:
未训练时生成的图像
训练36000个batch后生成的图像:
36000
训练108000个batch后生成的图像:
108000
训练184000个batch后生成的图像:
184000
可见在训练过程中GAN已经可以生成一定形状的图案了。

5、小结

  1. 尽管GAN所生成的图像还显得很简单,但它是一个很好的开始,在它之后各种不同结构的对抗网络层出不穷,效果也越来越好。
  2. 在GAN中,生成器和鉴别器都使用了线性结构,其主干是由Linear层构成的,在后续的网络中(如DCGAN),使用了卷积层来提升其效果。
  3. GAN所生成的手写体是随机的,在某种程度上希望它能够有指向性地生成指定数字,这就是条件式GAN了。
  • 1
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值