GAN通过Keras生成mnist

GAN介绍

生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。GAN网络中最少有两个模块,分别是生成模型和判别模型。而GAN就是让这两个模块相互对抗。举个简单的例子:小明是一个画家,专门画假画的那种,小白是个鉴别师,需要鉴别出假的画作。小明的目的就是能够画出小白无法鉴别出真假的画作,而小白就要保证能够检验出假货。最开始可能小明的技术拙劣,小白一下就看出来了,但是小明在一次又一次的绘画中不断的学习,到某个时间小白发现无法鉴别出画作的真伪,于是小白也跑去学习,那么在这个不断学习,不断鉴别的过程中,小白和小明最终都达到一个动态平衡的时刻,那么这时候,小明的画作已经不能被小白鉴别出来。这就是生成式对抗网络。
那么,我们通过构建一个生成式对抗网络来尝试生成一下mnist数据集叭~~

Generator

# 构建生成模型
def build_generator(self):
	model = Sequential()
	model.add(Dense(256, input_dim=self.latent_dim))
	model.add(LeakyReLU(alpha=0.2))
	model.add(BatchNormalization(momentum=0.8))
	model.add(Dense(512))
	model.add(LeakyReLU(alpha=0.2))
	model.add(BatchNormalization(momentum=0.8))
	model.add(Dense(1024))
	model.add(LeakyReLU(alpha=0.2))
	model.add(BatchNormalization(momentum=0.8))
	model.add(Dense(np.prod(self.img_shape), activation='tanh'))
	model.add(Reshape(self.img_shape))
	model.summary()
	noise = Input(shape=(self.latent_dim,))
	img = model(noise)
	return Model(noise, img)

可以看出,这里生成器模块通过一个简单的神经网络,将输入的100维的数据最后输出成28281的图片。首先我们用的是LeakyReLU而不是RelU,是因为ReLU在当输入值为负的时候,输出始终为0,其一阶导数也始终为0,这样会导致神经元不能更新参数,也就是神经元不学习了,这种现象叫做“Dead Neuron”。而为了尽量避免这种情况出现,我们使用了LeakyReLU,该函数输出对负值输入有很小的坡度。由于导数总是不为零,这能减少静默神经元的出现,允许基于梯度的学习(虽然会很慢),解决了Relu函数进入负区间后,导致神经元不学习的问题。

Adversarial

def build_discriminator(self):
	model = Sequential()
	model.add(Flatten(input_shape=self.img_shape))
	model.add(Dense(512))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Dense(256))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Dense(1, activation='sigmoid'))
	model.summary()
	img = Input(shape=self.img_shape)
	validity = model(img)
	return Model(img, validity)

判别模型的目的是根据输入的图片判断出真伪。因此它的输入一个28281维的图片,输出是0到1之间的数,1代表判断这个图片是真的,0代表判断这个图片是假的。所以最后模型输出的是一个1维的数组。

训练思路

在每一轮迭代中,网络要做两件事
训练Generator,保持Adversarial不训练
首先生成一个100维的随机数,之后用Generator生成图片
然后把生成的假图片与真图片混合,真图片 label 为 1, 假图片 label 为 0
然后把组合的图片数据交由Adversarial进行判断,得到一个二分类的分类器
训练Adversarial,保持Generator不训练
此时我们已经有一个暂时还可以的D,那么G的作用就是要把D给骗过去,让D对G生成的图片输出1。
首先生成一个100维的随机数,之后用Generator生成图片
然后把label 设为1,交由Generator进行判断;
此时Generator的输出肯定不是 1,因为这个Generator我们已经训练过了,还算靠谱,对于假的图片输出肯定小于 1,记为score;
然后1-score 就是loss,然后调整 G 的参数,得到一个更好的Generator
最后经过一定的迭代之后,Generator能够输出一个足以以假乱真的图片。

实现的全部代码(复制)

PS:参考博客https://blog.csdn.net/weixin_44791964/article/details/103729797

from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import sys
import numpy as np


class GAN():
    def __init__(self):
        self.img_rows = 28  # 图片大小
        self.img_cols = 28  # 图片大小
        self.channels = 1   # 由于生成的是灰度图,所以维度为1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)  # 设置shape参数
        self.latent_dim = 100   # 输入100维度服从高斯分布的向量
        optimizer = Adam(0.0002, 0.5)   # 使用Adam的优化器
        # 构建和编译判别器
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])
        # 构建生成器
        self.generator = self.build_generator()
        # 生成器输入噪音,生成假的图片
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)
        # 为了组合模型,只训练生成器
        self.discriminator.trainable = False
        # 判别器将生成的图像作为输入并确定有效性
        validity = self.discriminator(img)
        # 训练生成器骗过判别器
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        model = Sequential()
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))
        model.summary()
        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        return Model(noise, img)

    def build_discriminator(self):
        model = Sequential()
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # 加载数据集
        (X_train, _), (_, _) = mnist.load_data()
        # 归一化到-1到1
        X_train = X_train / 127.5 - 1.
        print(X_train.shape)
        X_train = np.expand_dims(X_train, axis=3)
        print(X_train.shape)
        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        for epoch in range(epochs):
            # ---------------------
            #  训练判别器
            # ---------------------
            # X_train.shape[0]为数据集的数量,随机生成batch_size个数量的随机数,作为数据的索引
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            # 从数据集随机挑选batch_size个数据,作为一个批次训练
            imgs = X_train[idx]
            # 噪音维度(batch_size,100)
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # 由生成器根据噪音生成假的图片
            gen_imgs = self.generator.predict(noise)
            # 训练判别器,判别器希望真实图片,打上标签1,假的图片打上标签0
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # ---------------------
            #  训练生成器
            # ---------------------
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)
            # 打印loss值
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
            # 没sample_interval个epoch保存一次生成图片
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
                if not os.path.exists("keras_model"):
                    os.makedirs("keras_model")
                self.generator.save_weights("keras_model/G_model%d.hdf5" % epoch, True)
                self.discriminator.save_weights("keras_model/D_model%d.hdf5" % epoch, True)

    def sample_images(self, epoch):
        r, c = 5, 5
        # 重新生成一批噪音,维度为(25,100)
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # 将生成的图片重新归整到0-1之间
        gen_imgs = 0.5 * gen_imgs + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        if not os.path.exists("keras_imgs"):
            os.makedirs("keras_imgs")
        fig.savefig("keras_imgs/%d.png" % epoch)
        plt.close()

    def test(self, gen_nums=100):
        self.generator.load_weights("keras_model/G_model15000.hdf5", by_name=True)
        self.discriminator.load_weights("keras_model/D_model15000.hdf5", by_name=True)
        noise = np.random.normal(0, 1, (gen_nums, self.latent_dim))
        gen = self.generator.predict(noise)
        print(gen.shape)
        # 重整图片到0-1
        gen = 0.5 * gen + 0.5
        for i in range(0, len(gen)):
            plt.figure(figsize=(128, 128), dpi=1)
            plt.imshow(gen[i, :, :, 0], cmap="gray")
            plt.axis("off")
            if not os.path.exists("keras_gen"):
                os.makedirs("keras_gen")
            plt.savefig("keras_gen" + os.sep + str(i) + '.jpg', dpi=1)
            plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, sample_interval=1000)
    gan.test()

最后训练的结果

在这里插入图片描述

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值