GAN 生成MNIST数据集

1、GAN是什么

GAN(生成式对抗网络,Generative Adversarial Networks)是一种深度学习模型,模型通过框架中两个模块(生成模型 Generative Model 和判别模型 Discriminative Model)的互相博弈学习,从而产生相当好的输出。原始GAN理论中,只要求G和D能拟合出相应生成和判别的函数即可,而并不要求他们必须都是神经网络,但是我们的实际应用中,一般都是采用深度神经网络作为G和D。

GAN论文:https://arxiv.org/abs/1406.2661

2、GAN的原理

@基本原理

GAN分为一个判别器(Discriminator,简称D)和一个生成器(Generator,简称G),简单的说,G和D就是两个多层感知机或卷积神经网络,它的基本思想,即为G和D的生成博弈过程。

G是一个生成图片的网络,它接收一个随机的噪声z,并且通过这个噪声生成图片,记做G(z)

D是一个进行判别的网络,它可以判别出一张图片是不是真实的。即给D输入真图片,它会将label赋值为1,输入假图片,就将label赋值为0

在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D,使D认为自己生成的是真图片;而D的目标就是尽量将G生成的图片和真实的图片区别开来,这样G和D就形成了一个动态的博弈过程。

那么,最后博弈的结果是什么呢?在理想的状态下,G足以生成以假乱真的图片G(z),对于D来说,它难以判定G生成的图片究竟是不是真实的,因此有D(G(z))=0.5

具体的流程如图中所示。

首先有一个一代的G,它生成的是一些很差的图片然后有一个一代的D,它能很准确的把G生成的假图片和真实的图片区分出来,打上标签0。其实这个D就是一个二分类器,对生成的图片输出0,而对真实的图片输出1。

接着经过训练,出现了二代的G,它能生成稍好一点的图片,能让一代的D认为他生成的是真图片。这时也出现了二代的D,它能识别哪些图片是G生成的,哪些是真实的图片。

以此类推,会有三代,四代。。。。n 代的 G(generator) 和D( discriminator),最后 D 无法分辨生成的图片和真实图片,这个网络就拟合了。

这两种网络具体是怎样的呢?

@Discriminator Network

首先要说的是对抗网络,因为这个网络相对较为简单。

对抗网络简单来说就是一个判断真假的判别器,解决的是一个二分类的问题。输入一张真的图片时我们希望它的输出结果是1,输入一张假的图片我们希望它能输出0。这其实和原图片的类别没有什么关系,无论原图片是什么类别的图片,我们都统称它为真图片,label为1;而生成的图片它是假的,label为0

我们对D训练的过程就是希望这个判别器能准确地判别出真的图片和假的图片,对于这个二分类问题可以有很多解决的方法,比如 logistic回归,深层网络,卷积神经网络,循环神经网络都可以。

# 判别网络
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
                 nn.Linear(784, 256),
                 nn.LeakyReLU(0.2),
                 nn.Linear(256, 256),                                        nn.LeakyReLU(0.2),
                 nn.Linear(256, 1),
                 # sigmoid激活函数得到一个0到1之间的概率进行二分类
                 nn.Sigmoid())   
 
    def forward(self, x):
        x = self.dis(x)
        return x

@Generative Network

怎样才能生成一张假的图片呢?

首先给出一个简单高维的正态分布的噪声向量,如上图所示的D-dimensional noise vector,通过对它进行仿射变换,也就是将 xw+b 映射到一个更高的维度,然后将它重新排列成一个矩形,这样看起来就更加像一张图片,然后经过一系列的卷积、池化、激活操作,最后得到了一个与我们输入图片大小一模一样的噪声矩阵,这就是我们所说的假的图片。这个时候是怎样训练我们的生成器呢?其实通过判别器来得到结果,我们不断增大判别器识别生成图片为真的概率,在这一步我们不会更新判别器的参数,只会更新生成器的参数。

# 生成网络
class generator(nn.Module):
    def __init__(self, input_size):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            #Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
            nn.Tanh()
        )
 
    def forward(self, x):
        x = self.gen(x)
        return x

3、训练 Train

@判别器训练

判别器的训练由两部分构成,分别是真的图片判别为真,假的图片判别为假,而在这个过程中,生成器的参与不参与更新

首先是定义loss函数的度量方式和优化函数,loss度量使用二分类的交叉熵,优化函数要注意使用的学习率是0.0003

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

然后进入训练

img = img.view(num_img, -1)      # 将图片展开乘28x28=784
real_img = Variable(img).cuda()  # 将tensor变成Variable放入计算图中
real_label = Variable(torch.ones(num_img)).cuda()  # 定义真的label为1
fake_label = Variable(torch.zeros(num_img)).cuda() # 定义假的label为0

# 计算真图片的loss
real_out = D(real_img)  # 将真实的图片放入判别器中
d_loss_real = criterion(real_out, real_label)  # 得到真实图片的loss  
real_scores = real_out  # 真实图片放入判别器输出越接近1越好

# 计算假图片的loss
z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 随机生成一些噪声
fake_img = G(z)          # 将一张假的图片放入生成网络中
fake_out = D(fake_img)   # 判别器判断假的图片
d_loss_fake = criterion(fake_out, fake_label)  # 得到假的图片的loss
fake_scores = fake_out   # 假的图片放入判别器越接近0越好

# bp and optimize
d_loss = d_loss_real + d_loss_fake  # 将真假图片的loss加起来
d_optimizer.zero_grad()  # 归0梯度
d_loss.backward()        # 反向传播
d_optimizer.step()       # 更新参数

@生成器训练

在生成网络的训练过程中,我们生成假的图片,但是希望判别器能将它识别为真的图片。我们将判别器固定,将假的图片传入判别器的结果与真实label对应,反向传播更新的参数是生成网络里面的参数,这样我们就可以通过更新生成网络里面的参数来使得判别器判断生成的假的图片为真,这样就达到了生成对抗的作用。

# 计算假图片的loss
z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 得到随机噪声
fake_img = G(z)       # 生成假的图片
output = D(fake_img)  # 经过判别器得到结果
g_loss = criterion(output, real_label)  # 得到假的图片与真实图片label的loss

# bp and optimize
g_optimizer.zero_grad()  # 归0梯度
g_loss.backward()        # 反向传播
g_optimizer.step()       # 更新生成网络的参数

4、全部代码 (pytorch实现)

贴上程序完整的代码:

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os

if not os.path.exists('img'):
    os.mkdir('img')

def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out

# 初始化参数
batch_size = 128
num_epoch = 50
z_dimension = 100
# 对图片进行一些前期处理
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# img_transform = transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# ]
# 下载数据集
mnist = datasets.MNIST(
    root='mnist_data', train=True, transform=img_transform, download=False)
# 加载数据集
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)

# 判别网络
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(nn.Linear(784, 256),
                                 nn.LeakyReLU(0.2),
                                 nn.Linear(256, 256),
                                 nn.LeakyReLU(0.2), nn.Linear(256, 1),
                                 nn.Sigmoid())  # sigmoid激活函数得到一个0到1之间的概率进行二分类

    def forward(self, x):
        x = self.dis(x)
        return x


# 生成器
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh())  # Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间。

    def forward(self, x):
        x = self.gen(x)
        return x


D = discriminator()
G = generator()
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
# 判别器的训练由两部分组成,第一部分是真的图像判别为真,第二部分是假的图片判别为假,在这两个过程中,生成器的参数不参与更新。
# 二进制交叉熵损失和优化器
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
# 开始训练
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # ================================训练判别器===================================
        img = img.view(num_img, -1)  # # 将图片展开乘28x28=784
        # real_img = Variable(img).cuda()
        # real_label = Variable(torch.ones(num_img)).cuda()
        # fake_label = Variable(torch.zeros(num_img)).cuda()
        real_img = Variable(img)
        real_label = Variable(torch.ones(num_img))  # 定义真实label为1
        fake_label = Variable(torch.zeros(num_img))  # 定义假label为1

        # 计算 real_img 的损失
        real_out = D(real_img)  # 将真实的图片放入判别器中
        d_loss_real = criterion(real_out, real_label)  # 得到真实图片的loss
        real_scores = real_out  # 越接近一越好

        # 计算 fake_img的损失
        # z = Variable(torch.randn(num_img, z_dimension)).cuda()
        z = Variable(torch.randn(num_img, z_dimension))  # 随机生成一些噪声
        fake_img = G(z)  # 放入生成网络生成一张假的图片
        fake_out = D(fake_img)  ## 判别器判断假的图片
        d_loss_fake = criterion(fake_out, fake_label)  ## 得到假的图片的loss
        fake_scores = fake_out  # 越接近0越好

        # 反向传播和优化
        d_loss = d_loss_real + d_loss_fake  # 将真假图片的loss加起来
        d_optimizer.zero_grad()  # 每次梯度归零
        d_loss.backward()  # 反向传播
        d_optimizer.step()  # 更新参数

        # =================================训练生成器================================

        # 计算fake_img损失
        # z = Variable(torch.randn(num_img, z_dimension)).cuda()
        z = Variable(torch.randn(num_img, z_dimension))  # 得到随机噪声
        fake_img = G(z)  # 生成假的图片
        output = D(fake_img)  # 经过判别器得到结果
        g_loss = criterion(output, real_label)  ##得到假的图片与真实图片label的loss

        # 反向传播和优化
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f},D real: {:.6f}, D fake: {:.6f}'.format(
                epoch, num_epoch, d_loss.item(), g_loss.item(),
                real_scores.data.mean(), fake_scores.data.mean()))

    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, 'real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, 'fake_images-{}.png'.format(epoch + 1))
torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')

5、结果 Result

结果展示:

在这里插入图片描述

随着epoch的增加,可以发现产生的噪声更少了,训练也更加稳定,图片中的数字也从模糊逐渐变为清晰,epoch-49中的图片简直就像真的图片一样。

  • 1
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值