不服就GAN:keras --- GAN 网络生成手写数字实例,生成数字 6

代码

1. 导包

import keras,os
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.losses import *
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from keras.preprocessing import image


from keras.datasets import fashion_mnist,cifar10,cifar100,mnist
from keras.utils import to_categorical

os.environ["CUDA_VISIBLE_DEVICES"] = " 2"

2. 鉴别器和生成器的定义、创建

def generator(input_shape):
    inputs = Input(input_shape)
    # 先全连接到64*7*7的维度上
    x = Dense(128 * 14 * 14)(inputs)
    x = LeakyReLU(0.2)(x)
    x = Reshape((14, 14, 128))(x)

    x = Conv2D(256, 5, padding = 'same')(x)
    x = LeakyReLU(0.2)(x)
    

    x = Conv2DTranspose(256, 4, strides = 2, padding = 'same')(x)
    x = LeakyReLU(0.2)(x)
    

    x = Conv2D(256, 5, padding = 'same')(x)
    x = LeakyReLU(0.2)(x)
    
    x = Conv2D(256, 5, padding = 'same')(x)
    x = LeakyReLU(0.2)(x)
    

    x = Conv2D(1, 7, activation='tanh', padding = 'same')(x)
    return Model(inputs,x)
                  


def discriminator(input_shape):
                  
    inputs = Input(input_shape)
    # 28, 28, 1 -> 14, 14, 32
    x = Conv2D(128, 3)(inputs)
    x = LeakyReLU(0.2)(x)
    
    x = Conv2D(128,4,strides = 2)(x)
    x = LeakyReLU(0.2)(x)
    
    x = Conv2D(128,4,strides = 2)(x)
    x = LeakyReLU(0.2)(x)
    
    x = Conv2D(128, 4,strides = 2)(x)
    x = LeakyReLU(0.2)(x)

    x = Flatten()(x)

    x = Dropout(0.4)(x)
    x = Dense(1, activation='sigmoid')(x) #分类层

    return Model(inputs,x)


gen = generator((100,))

dis = discriminator((28,28,1))

dis.compile(loss=keras.losses.binary_crossentropy,optimizer= keras.optimizers.RMSprop(lr = 0.0008,clipvalue = 1.0,decay=1e-8))

3. 联合生成器和鉴别器创建 GAN 网络

def GAN():
    gan_input = Input((100,))
    fake_image = gen(gan_input)
    dis.trainable=False
    score = dis(fake_image)
    return Model(gan_input,score)

gan = GAN()
gan.compile(loss=keras.losses.binary_crossentropy,optimizer=keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8))

4. 数据导入+规范化

(x_train,y_train),(x_test,y_test)= mnist.load_data()

x_train = x_train[y_train.flatten() == 6] 

x_train = x_train.reshape(x_train.shape[0],28,28,1).astype('float32')/255.

5. 训练

epochs = 10000
batch_size = 64


generated_img = []
discriminator_loss = []
generator_loss = []
save_dir = './A-GAN-PHOTO'

start = 0

for epoch in range(epochs):
    
    noise = np.random.normal(size=(batch_size,100))
    stop = start + batch_size
    
    real_img = x_train[start:stop]
    
    fake_img = gen.predict(noise)
    
    data = np.concatenate([fake_img, real_img])
    
    valid = np.ones((batch_size,1))
    fake = np.zeros((batch_size,1))
    
    label = np.concatenate([fake,valid])
    label += 0.05 * np.random.random(label.shape)  ## 训练时加入噪声
    
    d_loss = dis.train_on_batch(data,label)

    # ---------------------
    #  训练生成模型
    # ---------------------
    noise_ = np.random.normal(size=(batch_size,100))
    g_loss = gan.train_on_batch(noise_, valid)
    
#     dis.trainable=True
#     dis.compile(loss=keras.losses.binary_crossentropy,optimizer= keras.optimizers.RMSprop(lr = 0.0008,clipvalue = 1.0,decay=1e-8))
    
    start += batch_size
    if start > len(x_train) - batch_size:
        start = 0
        

    if epoch%100 == 0:
#         im = fake_img[0].reshape((28,28))
        im = fake_img[0]
        
#         im = fake_img[0].reshape(32,32,3)
        generated_img.append(im)
        img = image.array_to_img(im * 255, scale=False)
        img.save(os.path.join(save_dir, 'fake_six' + str(epoch) + '.png'))	#保存一张生成图像

        img = image.array_to_img(real_img[0] * 255, scale=False)
        img.save(os.path.join(save_dir, 'real_six' + str(epoch) +'.png'))   #保存一张真实图像用于对比

        print('discriminator_loss:',d_loss)
        print('adversal_loss:',g_loss)
        discriminator_loss.append(d_loss)
        generator_loss.append(g_loss)
        print("epoch:%d" % epoch + "========")

6. 可视化

fig, axes = plt.subplots(nrows=5, ncols=20, sharex=True, sharey=True, figsize=(80,12))
imgs = [i.reshape(28,28) for i in generated_img]
# imgs = generated_img

for image, row in zip([imgs[:20], imgs[20:40],imgs[40:60],imgs[60:80],imgs[80:100]], axes):
# for image, row in zip([imgs[0:10],imgs[5:10]], axes):
    
    for img, ax in zip(image, row):
        ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

fig.tight_layout(pad=0.1)

在这里插入图片描述

plt.plot(discriminator_loss,label='discriminator_loss')
plt.plot(generator_loss,label='generator_loss')
plt.legend()

在这里插入图片描述

fig, axes = plt.subplots(nrows=5, ncols=20, sharex=True, sharey=True, figsize=(80,12))
imgs = [i.reshape(28,28) for i in generated_img]
# imgs = generated_img

for image, row in zip([imgs[:20], imgs[20:40],imgs[40:60],imgs[60:80],imgs[80:100]], axes):
# for image, row in zip([imgs[0:10],imgs[5:10]], axes):
    
    for img, ax in zip(image, row):
        ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

fig.tight_layout(pad=0.1)

在这里插入图片描述

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

暖仔会飞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值