深度学习第54讲:训练一个深度卷积对抗网络DCGAN

     自从GoodFellow提出GAN以后,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。为了解决这些问题,后来的研究者不断推陈出新,以至于现在有着各种各样的GAN变体和升级网络。比如 LSGAN,WGAN,WGAN-GP,DRAGAN,CGAN,infoGAN, ACGAN,EBGAN,BEGAN,DCGAN以及最近号称史上最强图像生成网络的BigGAN等等。本节仅选取其中的DCGAN——深度卷积对抗网络进行简单讲解并利用keras进行实现。

     DCGAN的原始论文为 UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS,所谓DCGAN,顾名思义就是生成器和判别器都是深度卷积神经网络的GAN。

640?wx_fmt=png

搭建一个稳健的DCGAN要点在于:

  • 所有的pooling层使用步幅卷积(判别网络)和微步幅度卷积(生成网络)进行替换。

  • 在生成网络和判别网络上使用批处理规范化。

  • 对于更深的架构移除全连接隐藏层。

  • 在生成网络的所有层上使用ReLU激活函数,除了输出层使用Tanh激活函数。

  • 在判别网络的所有层上使用LeakyReLU激活函数。

640?wx_fmt=png

     基于DCGAN生成的卧室图片:

640?wx_fmt=png

     下面就基于keras搭建一个DCGAN。

from keras.layers import Dense, Conv2D, LeakyReLU, Dropout, Input
from keras.layers import Reshape, Conv2DTranspose, Flatten
from keras.models import Model
from keras import optimizers
import kerasimport numpy as npimport warnings
warnings.filterwarnings('ignore')

     设置相关参数:

# 潜变量维度
latent_dim = 32
# 输入像素维度
height = 32
width = 32
channels = 3

     下面开始搭建生成器网络:

generator_input = Input(shape=(latent_dim,))
x = Dense(128 * 16 * 16)(generator_input)
x = LeakyReLU()(x)
x = Reshape((16, 16, 128))(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(channels, 7, activation='tanh', padding='same')(x)
generator = Model(generator_input, x)
generator.summary()

     生成器网络概要如下:

640?wx_fmt=png

     然后搭建判别器网络:

discriminator_input = Input(shape=(height, width, channels))
x = Conv2D(128, 3)(discriminator_input)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Flatten()(x)
x = Dropout(0.4)(x)
x = Dense(1, activation='sigmoid')(x)
discriminator = Model(discriminator_input, x)
discriminator.summary()
discriminator_optimizer = optimizers.RMSprop(lr=0.0008, 
                                             clipvalue=1.0, 
                                             decay=1e-8)

discriminator.compile(optimizer=discriminator_optimizer,
                      loss='binary_crossentropy')

     判别器网络概要如下:

640?wx_fmt=png

     将生成器网络和判别器网络进行组合成DCGAN:

# 将判别器参数设置为不可训练
discriminator.trainable = False
gan_input = Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
# 搭建对抗网络
gan = Model(gan_input, gan_output)
gan_optimizer = optimizers.RMSprop(lr=0.0004, 
                                   clipvalue=1.0, 
                                   decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')

     DCGAN搭建完成之后,我们使用CIFAR-10数据来进行训练,构建训练代码如下:

import os
from keras.preprocessing import image
# 加载cifar-10数据
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()
# 指定青蛙图像(编号为6)
x_train = x_train[y_train.flatten() == 6]
x_train = x_train.reshape((x_train.shape[0],) +(height, width, channels)).astype('float32') / 255.
iterations = 10000

batch_size = 20

save_dir = './image'

start = 0
for step in range(iterations):    
    # 潜在空间随机采样
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))    # 解码生成虚假图像
    generated_images = generator.predict(random_latent_vectors)
    stop = start + batch_size
    real_images = x_train[start: stop]    
    # 将虚假图像和真实图像混合
    combined_images = np.concatenate([generated_images, real_images])    # 合并标签,区分真实和虚假图像
    labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])    
    # 向标签中添加随机噪声
    labels += 0.05 * np.random.random(labels.shape)    
    # 训练判别器
    d_loss = discriminator.train_on_batch(combined_images, labels)    
    # 潜在空间随机采样
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))    
    # 合并标签,以假乱真
    misleading_targets = np.zeros((batch_size, 1))    
    # 通过gan模型来训练生成器模型,冻结判别器模型权重
    a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)
    start += batch_size    
    if start > len(x_train) - batch_size:
        start = 0
    # 每100步绘图并保存
    if step % 100 == 0:
        gan.save_weights('gan.h5')
        print('discriminator loss:', d_loss)
        print('adversarial loss:', a_loss)
        img = image.array_to_img(generated_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))
        img = image.array_to_img(real_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))

     训练过程如下:

640?wx_fmt=png

     DCGAN生成的青蛙图片和真实图片混在一起如下图所示,能否辨别出哪张是真实样本,哪张是DCGAN生成的样本?

640?wx_fmt=png

     受限于CIFAR-10数据本身的低像素性,DCGAN生成出来的图像虽然也很模糊,但基本上足以达到以假乱真的水平。上图图片中,每一列有两张是生成样本,有一张是真实样本,按列第2、1、3和2张图片是真实样本,其余都是DCGAN伪造出来的青蛙图片。

     以上便是本节内容。

参考资料:

UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS

https://blog.csdn.net/liuxiao214/article/details/74502975

thttp://www.twistedwg.com/2018/01/31/Various-GAN.html

Deep Learning with Python

往期精彩:


一个数据科学从业者的学习历程

640?

640?wx_fmt=jpeg

长按二维码.关注机器学习实验室

640?wx_fmt=jpeg

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值