【代码篇】【1】详解GAN代码生成mnist图片(keras)

本文详细介绍了如何使用Keras实现DCGAN(深度卷积对抗生成网络)来生成MNIST手写数字图像。文章涵盖了GAN的基本概念、训练过程、DCGAN的特点以及基于Keras的生成器和判别器的网络结构。通过交替训练生成器和判别器,最终达到生成逼真图像的效果。
摘要由CSDN通过智能技术生成

【代码篇】【1】详解GAN代码生成mnist图片(keras)

在这里插入图片描述

0.GAN的基本概念

GAN (Generative Adversarial Networks)从其名字可以看出,是一种生成式的对抗网络。再具体一点, 就是通过对抗的方式,去学习数据分布的生成式模型。所谓的对抗,指的是生成网络和判别网络的互相对抗。生成网络尽可能生成逼真的样本,判别网络则尽可能去判别该样本是真实样本,还是生成的假样本。

在这里插入图片描述
隐变量z (通常为服从高斯分布的随机噪声)通过Generator生成X fake,判别器负责判别输入的data是生成的样本X fake还是真实样本Xreal。优化的目标函数如下:
在这里插入图片描述
对于判别器D来说,这是一个二分类问题,V(D, G)为二分类问题中常见的交叉熵损失。对于生成器G来说,为了尽可能欺骗D,所以需要最大化生成样本的判别概率D(G(z)),即最小化log(1 - D(G())),log(D(x))- -项与生成器G无关,可以忽略。

1.如何训练GAN?

实际训练时,生成器和判别器采取交替训练,即先训练D,然后训练G,不断往复。值得注意的是,对于生成器,其最小化的是max V(D, G),即最小化V(D, G)的最大值。为了保证V (D, G)取得最大值,所以我们通常会训练迭代k次判别器,然后再迭代1次生成器(不过在实践当中发现,k通常取1即可)。当生成器G固定时,我们可以对V(D, G)求导,求出最优判别器D* (x):
在这里插入图片描述
把最优判别器代入上述目标函数,可以进一步求出在最优判别器下,生成器的目
标函数等价于优化Pdata (x), Pg(x )的JS散度(JSD, Jenson Shannon Divergence)。可以证明,当G, D二者的capacity足够时,模型会收敛,二者将达到纳什均衡。此时,Pdata(x)= Pg(x),判别器不论是对于pdata (x)还是pg(x)中采样的样本,其预测概率均为0.5,即生成样本与真实样本达到了难以区分的地步。

3.GAN的常见模型DCGAN

DCGAN的全称是Deep Convolutional Generative Adversarial Networks ,意即深度卷积对抗生成网络。它是由Alec Radford在论文Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks中提出的。实际上它是在GAN的基础上增加深度卷积网络结构。

DCGAN提出使用CNN结构来稳定GAN的训练,并使用了以下一些trick:
在这里插入图片描述

4.基于keras的DCGAN

代码主要是参考Bubbliiiing同学的代码,稍微的修改了网络结构,一共分为5个模块讲解:
Keras搭建DCGAN利用深度卷积神经网络实现图片生成

4.1生成器generator

生成器的输入为基于正态分布的N维向量noise,输出为28x28x1的mnist图片。
首先输入noise,全连接到7x7x16大小( latent_dim —> 7x7x16)。
然后使用多次上采样+卷积+batchnorm+relu模块,直到28x28x1大小。
值得注意的是,最后一层使用tanh激活函数,效果要好些。
model.summary()为打印模型。

    def make_generator(self):
        #----------------------------#
        #      make generator        #
        #----------------------------#
        model = Sequential()
        # latent_dim ---> 7x7x16
        model.add(Dense(7*7*16, activation='relu', input_dim=self.latent_dim))
        model.add(Reshape((7,7,16)))
        # 7x7x16 ---> 7x7x32
        model.add(Conv2D(32, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        # 7x7x32 ---> 14x14x64
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))        
        # 7x7x64 ---> 28x28x128
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))  
        # 28x28x128 ---> 28x28x32
        model.add(
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值