keras学习之:全网最简单GAN 网络研究,教你产生手写的数字

导包

import keras
from keras.layers import *
from keras.optimizers import *
from keras.losses import *
from keras.utils import to_categorical
from keras.models import Model
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt

数据集(手写数字识别)

(x_train,y_train),(x_test,y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 28, 28,1)
x_test = x_test.reshape(10000, 28, 28,1)

挑选数据:使用 “1” 这个类别的图片作为基准

indexes = np.where(y_train==1)          ## “1” 类别的图片的所有索引

real_img_count = indexes[0].shape[0]    ## 查看选出来的图片的数量 --> 6742print(real_img_count)

target_imgs = x_train[indexes]          ## 通过索引把 x_train 中所有的 “1” 拿出来
target_labels = y_train[indexes]        ## 通过索引把 y_train 中所有的 “1” 拿出来

plt.imshow(target_imgs[2].reshape(28,28))  # 画个图看一下我们是不是真的拿到了想要的图片
  • 最后一行的 reshape 只是因为 plt.imshow 要求的输出格式的原因。
    在这里插入图片描述

核心代码

1.生成器

def generator():
    noise = Input((100,))			
    x = Dense(2048,activation=None)(noise)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)
    x = Dense(1024,activation=None)(x)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)
    x = Dense(784,activation=None)(x)
    x = LeakyReLU()(x)
    x = Reshape((28,28,1))(x)
    return Model(noise,x)
gen = generator()
  • 看网上很多教程,把 generator 写的很复杂,还通过卷积层和 upsampling 进行上采样,不利于初学者理解,其实 generator 的构造就注意三点即可:
    1. 一定要通过各种网络层产生足够的参数用来训练(既可以用卷积层也可以粗暴的用全连接层,这样对初学者搭建网络很友好),比如本文为了掩饰简单粗暴就使用了 Dense 层的堆积
    2. 保证最后一层全连接层的参数能够被 reshape 成你想要产生的图片的维度,比如我们要产生的图片维度是 28*28*1,那么我们就最后一个 Dense 层设计为 784,如果你要产生的图片维度是 32*32*3,那最后一个Dense层就设计为 32*32*3 = 3072。再通过最后的 Reshape 层进行约束即可
    3. 最后一点是防止梯度的消失问题,尽量使用 LeakyRelu 而不用 sigmoid 或者其他激活函数(这是公认的经验);为了收敛迅速使用 BatchNormalization 层锦上添花
  • 因为后面,generator 要包在 GAN 网络里面,通过 GAN 网络的训练来迭代 generator 的参数,所以 generator 不需要单独训练,因此不用在这里 compile generator
  • 我的 generator 的参数量我控制在 300 万左右,如下图所示:
    在这里插入图片描述

2. 鉴别器

def discriminator():
    inputs = Input((28,28,1))
    x = Conv2D(filters=64,kernel_size=3,strides=2,padding='valid')(inputs)
    x = Conv2D(filters=128,kernel_size=3,strides=2,padding='valid')(x)
    x = Conv2D(filters=256,kernel_size=3,strides=2,padding='valid')(x)
    x = Flatten()(x)
    x = Dense(512,activation='relu')(x)
    x = Dense(1,activation='sigmoid')(x)
    
    return Model(inputs,x)
dis = discriminator()
dis.compile(optimizer=Adam(0.0001),loss=binary_crossentropy,metrics=['accuracy'])
  • 鉴别器就更简单了,你只需要保证两点即可:

    1. 输入的维度是你图片的维度
    2. 产生足够的参数来进行训练
    3. 最后一个 Dense 层只有 1 个神经元;因为 discriminator 的作用就是把输入分类成 0 或者 1,他认为真的样本他就给他 1,他认为假的样本就给 0;因此最后一层肯定要用 sigmoid 函数归一化到 0~1 之间
  • 因为 discriminator 需要单独训练,因此在这里先 compile 一下

  • 我把鉴别器的参数量控制在 100 万左右。

3. GAN 网络的构造

batch_size = 64			
epochs = 10000

generated_img = []								# 每训练100个epochs就把一张图片放入这个列表,看整个过程产生的图片的差别
discriminator_loss = []							# 用来存放每个 epoch 的鉴别器的 loss
gan_loss = []									# 用来存放每个 epoch 的 GAN 网络的 loss,其实就代表生成器的 loss,因为训练 GAN 的时候冻结了鉴别器的参数
for epoch in range(epochs):

    ## 随机产生噪声,每个 epoch 产生一个 batchsize 这么多的 (100,)的噪声
    noise = np.random.random((batch_size,100))
    
	## 每个 epoch 都用生成器将噪声生成图片,这个图片是 fake_imgs,对应的标签就是 0,所以每个 epoch 标签的个数就是 batch_size 个
	
    fake_img = gen.predict(noise)				 # fake_img 的维度 --> (batchsize,28,28,1)
    fake_img_labels = np.zeros((batch_size,1))   # ---> 0 ,维度-->(batchsize,1)
    
    ## 每个 epoch 都从我们选出的 “1” 的所有图片中挑选随机的 batchsize 张图作为 real_imgs
    pick_indexes = np.random.randint(0,real_img_count,(batch_size,))			# np.random.randint(0,6742,(64,))
    real_img = target_imgs[pick_indexes]		# real_imgs 的维度 --> (batchsize,28,28,1)
    real_img_labels = np.ones((batch_size,1))    # ---> 1,维度-->(batchsize,1)
    
    ## 结合 fake_imgs 和 real_imgs 以及 他们的 label
    data = np.concatenate([fake_img,real_img])
    labels = np.concatenate([fake_img_labels,real_img_labels])
    
    ## 训练 discriminator
    loss = dis.train_on_batch(data,labels)
    discriminator_loss.append(loss[-1])				#每一个 batchsize中取最后一个 loss 作为这个 epoch 的loss
    print("discriminator_loss:%f"%loss[-1])

    ## 训练 gan 网络,更新 generator 的参数
    
    loss_gan = ga.train_on_batch(noise,real_img_labels)
    print("loss_gan:%f"%loss_gan[-1])
    
    gan_loss.append(loss_gan[-1])
    print("epoch:%d" % epoch + "========")
   
    
    ## 每 100 个 epoch 就存一张图到 generated 列表中,10000个 epoch 就相当于存了100张图片进去
    if epoch%100 == 0:
        im = fake_img[0].reshape((28,28))
        generated_img.append(im)
    

打印最后生成的结果

1. 前10张图的生成效果

for i in range(0,10):
    plt.subplot(2,5,i+1)		# 只打印前十张图
    plt.imshow(generated_img[i])

在这里插入图片描述

2. 最后10张图的生成效果

for i in range(90,100):
    plt.subplot(2,5,i-90+1)		## 只打印最后十张图
    plt.imshow(generated_img[i])

在这里插入图片描述

  • 可以看到,我们已经几乎成功地完成了 “1” 类图片的生成

打印训练过程的 loss

plt.plot(discriminator_loss)
plt.plot(gan_loss)

在这里插入图片描述

  • 可以看看到,GAN 在训练的时候,loss 一直在震荡,没有稳定的收敛,虽然没有收敛,但是依然产生了很好效果的数据,至于如何解决 loss 震荡和收敛的问题,大家可以去查阅资料作进一步的了解。

整体代码

  • 为了给大家更加直观的整体效果,在这里附上整体的代码,我使用的是:
    • jupyter notebook 来实现的
    • keras-gpu 2.3.1;
    • GTX1080Ti,
    • 10000次 epoch 只需要 5 分钟就可以跑完~~
    import keras
    from keras.layers import *
    from keras.optimizers import *
    from keras.losses import *
    from keras.utils import to_categorical
    from keras.models import Model
    from keras.datasets import mnist
    import numpy as np
    import matplotlib.pyplot as plt
    
    (x_train,y_train),(x_test,y_test) = mnist.load_data()
    
    x_train = x_train.reshape(60000, 28, 28,1)
    x_test = x_test.reshape(10000, 28, 28,1)
    
    # y_train = to_categorical(y_train)
    # y_test = to_categorical(y_test)
    
    ### 第一类的索引,选出第一类的数据以及其 labels
    
    indexes = np.where(y_train==1)          ## “1” 类别的图片的所有索引
    
    real_img_count = indexes[0].shape[0]    ## 查看选出来的图片的数量 --> 6742print(real_img_count)
    target_imgs = x_train[indexes]          ## 通过索引把 x_train 中所有的 “1” 拿出来
    target_labels = y_train[indexes]        ## 通过索引把 y_train 中所有的 “1” 拿出来
    
    plt.imshow(target_imgs[2].reshape(28,28))  # 画个图看一下我们是不是真的拿到了想要的图片
    
    ### 建立 GAN 网络,用来生成第一类的图片
    
    def generator():
        noise = Input((100,))
        x = Dense(2048,activation=None)(noise)
        x = LeakyReLU()(x)
        x = BatchNormalization()(x)
        x = Dense(1024,activation=None)(x)
        x = LeakyReLU()(x)
        x = BatchNormalization()(x)
        x = Dense(784,activation=None)(x)
        x = LeakyReLU()(x)
        x = Reshape((28,28,1))(x)
        return Model(noise,x)
    
    gen = generator()
    
    
    def discriminator():
        inputs = Input((28,28,1))
        x = Conv2D(filters=64,kernel_size=3,strides=2,padding='valid')(inputs)
        x = Conv2D(filters=128,kernel_size=3,strides=2,padding='valid')(x)
        x = Conv2D(filters=256,kernel_size=3,strides=2,padding='valid')(x)
        x = Flatten()(x)
        x = Dense(512,activation='relu')(x)
        x = Dense(1,activation='sigmoid')(x)
        
        return Model(inputs,x)
    
    dis = discriminator()
    dis.compile(optimizer=Adam(0.0001),loss=binary_crossentropy,metrics=['accuracy'])
    
    def gan():
        inputs = Input((100,))
        fake_img = gen(inputs)
        dis.trainable=False
        score = dis(fake_img)
        return Model(inputs,score)    ## 这个 score 越接近1越好
    
    ga = gan()
    ga.compile(optimizer=RMSprop(0.0001),loss=binary_crossentropy,metrics=['accuracy'])
    
    
    batch_size = 64
    epochs = 10000
    
    generated_img = []
    discriminator_loss = []
    gan_loss = []
    
    for epoch in range(epochs):
        ## 随机产生噪声
        noise = np.random.random((batch_size,100))
        
        fake_img = gen.predict(noise)
        fake_img_labels = np.zeros((batch_size,1))   # ---> 0 
        
        pick_indexes = np.random.randint(0,real_img_count,(batch_size,))
        real_img = target_imgs[pick_indexes]
        real_img_labels = np.ones((batch_size,1))    # ---> 1
        
        ## 结合 fake_imgs 和 real_imgs 以及 他们的 label
        data = np.concatenate([fake_img,real_img])
        labels = np.concatenate([fake_img_labels,real_img_labels])
        
        ## 训练 discriminator
        
        loss = dis.train_on_batch(data,labels)
        discriminator_loss.append(loss[-1])
        print("discriminator_loss:%f"%loss[-1])
    
        
        ## 训练 gan 网络,更新 generator 的参数
        
        loss_gan = ga.train_on_batch(noise,real_img_labels)
        print("loss_gan:%f"%loss_gan[-1])
        
        gan_loss.append(loss_gan[-1])
        print("epoch:%d" % epoch + "========")
    
    
        
        if epoch%100 == 0:
            im = fake_img[0].reshape((28,28))
            generated_img.append(im)
        
    
    
    for i in range(0,10):
        plt.subplot(2,5,i+1)
        plt.imshow(generated_img[i])
    
    for i in range(90,100):
        plt.subplot(2,5,i-90+1)
        plt.imshow(generated_img[i])
    
    
    plt.plot(discriminator_loss)
    plt.plot(gan_loss)
    
    

牛刀小试:生成 fashion_mnist数据集中的图片

  • 结果还是很好的~
  • 一共10000个 epoch
  • 每1000 个 epoch 画一张图

结果

在这里插入图片描述

loss

在这里插入图片描述

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

暖仔会飞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值