生成对抗网络GAN的keras实例

该博客详细介绍了如何使用Keras构建和训练生成对抗网络(GAN)。首先,定义了优化器Adam,然后分别构建了鉴别器和生成器模型。鉴别器采用多层全连接网络,生成器则包含多个LeakyReLU激活的密集层。接下来,将这两个模型组合成一个完整的GAN模型,并进行了训练。训练过程中,不断迭代更新生成器和鉴别器,以生成接近真实数据的虚假图片。最后,展示了训练过程中生成的虚假图片与真实图片的对比。
摘要由CSDN通过智能技术生成

生成对抗网络GAN的keras实例

导入一些需要的包

from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

定义优化器

optimizer = Adam(0.0002, 0.5)

构建鉴别器并编译

n_y_value = 20
D = Sequential()
D.add(Dense(512))
D.add(LeakyReLU(alpha=0.2))
D.add(Dense(256))
D.add(Dense(1, activation='sigmoid'))
# D.summary()
img = Input(shape=(n_y_value,))
validity = D(img)
Discriminator = Model(img,validity)
Discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

构建生成器,并组合生成器和鉴别器成GAN

N_ideas = 5
G = Sequential()
G.add(Dense(512,input_dim=N_ideas))
G.add(LeakyReLU(alpha=0.2))
G.add(BatchNormalization(momentum=0.8))
G.add(Dense(512))
G.add(LeakyReLU(alpha=0.2))
G.add(BatchNormalization(momentum=0.8))
G.add(Dense(1024))
G.add(LeakyReLU(alpha=0.2))
G.add(BatchNormalization(momentum=0.8))
G.add(Dense(n_y_value, activation='tanh'))
# G.add(Reshape(n_y_value))
# G.summary()
noise = Input(shape=(N_ideas,))
G_img = G(noise)
Generator = Model(noise,G_img)

z = Input(shape=(N_ideas,))
G_img = Generator(z)
Discriminator.trainable = False
validity = Discriminator(G_img)
GAN = Model(z,validity)
GAN.compile(loss='binary_crossentropy', optimizer=optimizer)

训练过程

batch_size=64
x= np.vstack([np.linspace(-1,1,n_y_value) for _ in range(batch_size)])
true_imgs =np.power(x,2)



valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
plt.ion()
for i in range(1600):
    noise = np.random.normal(0, 1, (batch_size, N_ideas))
    G_imgs = Generator.predict(noise)

    d_loss_fake = Discriminator.train_on_batch(G_imgs,fake)
    d_loss_real = Discriminator.train_on_batch(true_imgs,valid)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    g_loss = GAN.train_on_batch(noise,valid)
    print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (i, d_loss[0], 100 * d_loss[1], g_loss))
    # print("G_imgs.shape:",G_imgs.shape)  #(64,20)
    plt.cla()

    plt.xlim((-1.2, 1.2))
    plt.ylim((-0.2, 1.2))
    plt.plot(x[0], true_imgs[0], lw=2, c='#11AAAA')
    plt.plot(x[0],G_imgs[0], lw=2, c='#B62A2A')
    plt.pause(0.01)

plt.ioff()
plt.show()

最终网络生成的虚假图片与真实图片如下
红色为虚假生成的二次函数曲线,蓝色为真实曲线

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值