GAN生成式对抗网络简介及MINST实现

一、什么是GAN生成式对抗网络

1.1.GAN的简介

生成式对抗网络(GAN,Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

1.2.GAN的组成

模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN 理论中,并不要求 G 和 D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D 。

训练完毕后的GAN模型是一个输入噪声即可得到想要的图片的网络结构。
在这里插入图片描述

1.3.GAN的发展历程

Ian J. Goodfellow等人于2014年10月在Generative Adversarial Networks中提出了一个通过对抗过程估计生成模型的新框架,被称作GAN。

GAN经过后续的发展,在多领域被广泛应用。

图像领域中,DCGAN(Deep convolutional NN for GAN)是一种类似反卷积的结构,采用一个随机噪声向量作为输入,如高斯噪声,通过与CNN相反的结构将输入放大成二维数据。

半监督学习中,GAN也有自己的应用。

模型改进方面,WGAN是GAN的改进版本,在GAN的训练中,判断模型训练的非常好的时候,可能生成模型结果很坏,GAN的本质其实是优化真实样本分布和生成样本分布之间的差异,并最小化这个差异。优化的目标函数是两个分布上的Jensen-Shannon距离,但这个距离有这样一个问题,如果两个分布的样本空间并不完全重合,这个距离是无法定义的。作者接着证明了“真实分布与生成分布的样本空间并不完全重合”是一个极大概率事件,并证明在一些假设条件下,可以从理论层面推导出一些实际中遇到的现象。该文章提出了一种解决方案:使用Wasserstein距离代替Jensen-Shannon距离。并依据Wasserstein距离设计了相应的算法,即WGAN。新的算法与原始GAN相比,参数更加不敏感,训练过程更加平滑。

二、GAN的原理

2.1.生成器(generator)

随机生成一个噪声,比如1000维向量,传入生成器,通过反卷积的方式,将其转化为xx3的图片。

在这里插入图片描述

2.2.判别器(discriminator)

就是一个分类模型,输入一个图,判别是真图还是假图。一般采用卷积的方式(建议理解成一个CNN)。

2.3.训练过程

首先是初始化生成器。随即输入一些噪声,生成一些图片,为这些图片打上假图标签,之后取一些真图,打上真图标签。(这里的真图是希望模型生成的事物,本质上就是让计算机模仿真图自己画出来图)

之后是用生成器训练判别器。将生成的假图和真图放入判别器中进行学习,使得判别器对于真图假图有比较高的判别能力。

然后用判别器训练生成器,让生成器不断生成假图,放到判别器中判断是真图还是假图,不断调整参数,直到判别器判断的真图假图之比是一比一(这里不是很理解,判别器判断的都是真图不是更能说明生成器很合适吗?),此时说明生成器已经可以以假乱真,应该重新提高判别器精度了。

最后,重复二三步骤,直到生成模型生成的图片质量达到了要求(这里埋下伏笔,怎么衡量模型生成的图片质量高不高,至今未被解决),同时,也有说法是达成了纳什均衡(纳什均衡指无法只通过改变某一模型使得该模型更好,本质上是不可能达到的,因为给定一个生成器,无论何时,你都无法保证判别器无法通过训练使得判断真图假图的准确率变得更高,只能说相对稳定,或者我们没有提高准确率的办法)
在这里插入图片描述

2.4.使用过程

GAN的使用过程就是感觉最后一轮训练得到的图形不错之后,随机加入一些噪声到生成器,得到想要的数据。也就是GAN最终得到的是一个可以通过随机数据生成图片的模型。

2.5.GAN的本质

GAN的本质其实是优化真实样本分布和生成样本分布之间的差异,说人话就是通过学习真东西,自己造出假东西,在不断实践中修正自己,直到造出很像真东西的假东西。

三、如何应用GAN

3.1.明确目的

是需要生成更多数据集以训练分类模型,还是只是生成一些图片看看效果。

通过GAN生成数据集用于训练在数据集上的分类模型是非常困难的,首先,要为生成的数据筛选并打标签,这需要人工,但也可以考虑训练多个生成器,每个生成器生成一类,但训练过程会非常繁琐。之后,要保证生成的图片对数据集的分类训练有帮助,按照逻辑思维想,GAN网络本质就是,用一些已知的量通过一些花哨的变化提取一些特征,GAN利用这些特征生成一些比较逼真的图片。此时,如果我们希望通过筛选逼真的图片手打标签,再次用分类网络提取其他特征使得网络分类更精准,那就类似于人根据自己的认知写数据集;如果我们要通过多个GAN提取图像特征,分类生成不同类的图片用于分类网络的训练,那无异于用GAN提取的特征拼接出的图片,再次提取特征,这是愚蠢的,不如直接想办法把GAN提取到的特征用于分类模型的初始化。综上,用GAN生成的图片作为数据集重新加入训练是依赖于人的。

3.2.明确网络结构

GAN有许多变化的网络结构,应该再明确目的后进行选择确认。

3.3.明确方法

一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。

四、MINST数据集上GAN的pytorch实现

import torch
from torch import nn
from torch.autograd import Variable

import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

plt.rcParams['figure.figsize'] = (10.0, 8.0)  # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'


def show_images(images):  # 定义画图工具
    images = np.reshape(images, [images.shape[0], -1])
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return


def preprocess_img(x):
    x = tfs.ToTensor()(x)
    return (x - 0.5) / 0.5


def deprocess_img(x):
    return (x + 1.0) / 2.0


class ChunkSampler(sampler.Sampler):  # 定义一个取样的函数

    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples


NUM_TRAIN = 50000
NUM_VAL = 5000

NOISE_DIM = 96
batch_size = 128

train_set = MNIST('./data', train=True, transform=preprocess_img)

train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))

val_set = MNIST('./data', train=True, transform=preprocess_img)

val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))

imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze()  # 可视化图片效果
show_images(imgs)


# 判别网络
def discriminator():
    net = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1)
    )
    return net


# 生成网络
def generator(noise_dim=NOISE_DIM):
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784),
        nn.Tanh()
    )
    return net


bce_loss = nn.BCEWithLogitsLoss()  # 交叉熵损失函数


def discriminator_loss(logits_real, logits_fake):  # 判别器的 loss
    size = logits_real.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    false_labels = Variable(torch.zeros(size, 1)).float()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss


def generator_loss(logits_fake):  # 生成器的 loss
    size = logits_fake.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    loss = bce_loss(logits_fake, true_labels)
    return loss


# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer


def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
                noise_size=96, num_epochs=10):
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]
            # 判别网络
            real_data = Variable(x).view(bs, -1)  # 真实数据
            logits_real = D_net(real_data)  # 判别网络得分

            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均匀分布
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # 生成的假的数据
            logits_fake = D_net(fake_images)  # 判别网络得分

            d_total_error = discriminator_loss(logits_real, logits_fake)  # 判别器的 loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # 优化判别网络

            # 生成网络
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # 生成的假的数据

            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)  # 生成网络的 loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # 优化生成网络

            if (iter_count % show_every == 0):
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
                show_images(imgs_numpy[0:16])
                plt.show()
                print()
            iter_count += 1


D = discriminator()
G = generator()

D_optim = get_optimizer(D)
G_optim = get_optimizer(G)

train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

参考 https://www.jb51.net/article/178171.htm
训练结果大概这样
在这里插入图片描述
在这里插入图片描述

五、GAN的评价与推广

5.1.GAN的评价

5.1.1.GAN的优势

最大化利用数据集
生成模型和判别模型互相借鉴学习,拟合

5.1.2.GAN的缺陷

对训练的超参数特别敏感,需要精心设计
如果将判别模型训练的很充分,生成模型可能变差

5.2.GAN的推广

没有公认的标准衡量不同图片生成算法的差异性,这个问题涵待解决。

GAN生成的图片没有标签,如果有后续使用要求,需要人工筛选并标记。

  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值