【生成对抗网络】GAN入门与代码实现(一)

生成对抗网络系列
【生成对抗网络】GAN入门与代码实现(一)
【生成对抗网络】GAN入门与代码实现(二)
【生成对抗网络】基于DCGAN的二次元人物头像生成(TensorFlow2)
【生成对抗网络】ACGAN的代码实现

1. 生成对抗网络介绍

生成对抗网络(Generative Adversarial Network)于2014年被Goodfellow等人提出,然后迅速流行。GAN能通过学习特定领域知识创造出新的图像、文本等。2016年,GAN热潮席卷人工智能领域顶级会议,从ICLR到NIPS,大量论文被发表和探讨。Yann LeCun曾评价GAN是“20年来机器学习领域最酷的想法”。

在GAN中主要由生成器G(generator)与判别器D(discriminator)构成。其中生成器用于生成逼真的假数据、判别器则需要在判别出真实数据与假数据,生成器与判别器相互博弈,在能力上有所提升,生成器生成的数据越来越像是真实的数据,判别器则能更好地将两者分辨出来,直到两者达到一种平衡。

假如以小狗图片作为生成的目标:

  • 生成器:接收一个随机噪声(随机变量)作为输入,输出一个小狗的图片(假图片)。
  • 判别器:将原真实的小狗图片和生成器生成的小狗图片两者区分出来,判断谁真谁假。

在模型训练的过程中:

​ 生成器:学习如何更好的将生成的小狗图片更加像真实,从而让判别器误认为是真实的。

​ 判别器:不断地将生成器生成的图片与真实的图片用于判别器模型的训练,提高自己的判别准确率。

GAN的整个训练过程如下:

  1. ​ 生成器接收随机噪声,并生成假图像;
  2. ​ 判别器接收假图像和真实图像组合的数据,学习如何判别真假图像;
  3. ​ 生成器生成新的图像,并使用判别器来判别真假,同时通过判别器来判别此次造假的水平;
  4. ​ 重复步骤 1-3。

2. 基于TensorFlow2的GAN的简单实现

我们以手写数据集MNIST为例进行演示。让GAN学习生成一些新的手写数字图片,每张图片的尺寸为28*28。
在这里插入图片描述

代码实现步骤如下:

  1. 定义生成器,接收随机噪声,输出图像张量
  2. 定义判别器,接收图像张量,输出真假张量
  3. 定义生成对抗网络,接收随机噪声,输出真假张量。生成对抗网络由前面定义的生成器的模型层和判别器的模型创建( 它们共享权重),同时需要冻结判别器的权重。
  4. 将随机噪声输入生成器,生成一批图像
  5. 使用生成的图像与真实图像训练判别器(假图像的目标为0,真图像的目标为1)
  6. 使用新随机噪声输入生成对抗网络,输出真假(使生成的假图像判别为1),提高“造假”水平
  7. 重复4-6步骤

2.1 导包与参数设置

import numpy as np # 用于数据处理
import tensorflow as tf # 版本2.0及以上
from tensorflow import keras # 主要使用keras实现
import tqdm # 进度条,使用pip install tqdm安装
import matplotlib.pyplot as plt # 绘图函数库
%matplotlib inline
LATENT_DTM = 100 # 随机噪声的长度
IMAGE_SHAPE = (28,28,1) # 手写数字图片的尺寸与通道数

2.2 生成器

生成器接收随机向量,然后通过模型生成一张手写数字图片。

关键点:

  • 使用随机噪声作为输入,保证模型具有一定的随机性
  • 使用tanh作为最后一层的激活函数,可以获得更好的效果
  • 使用LeakyReLU激活函数来代替ReLU激活函数
generator_net = [
    keras.layers.Input(shape=(LATENT_DTM,)), # 输入为长度100点随机向量
    keras.layers.Dense(256),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.BatchNormalization(momentum = 0.8),
    keras.layers.Dense(512),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.BatchNormalization(momentum = 0.8),
    keras.layers.Dense(1024),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.BatchNormalization(momentum = 0.8),
    keras.layers.Dense(np.prod(IMAGE_SHAPE),activation='tanh'),
    keras.layers.Reshape(IMAGE_SHAPE) #  将向量重塑shape为(28,28,1),输出图片
]
generator = keras.models.Sequential(generator_net)

2.3 判别器

判别器是一个二分类问题,接收一个图片,输出真假。

discriminator_net =[
    keras.layers.Input(shape=IMAGE_SHAPE),
    keras.layers.Flatten(),
    keras.layers.Dense(512),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.Dense(256),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.Dense(1,activation='sigmoid')
]
discriminator = keras.models.Sequential(discriminator_net)

优化器:

optimizer = keras.optimizers.Adam(0.0002,0.5)

模型编译:

discriminator.compile(loss=keras.losses.binary_crossentropy,optimizer=optimizer,metrics=['acc'])

2.4 搭建生成对抗网络

将生成器与判别器组合在一起,同时冻结判别器的权重。

该过程将生成器生成的图片直接送入判别器模型,从而直接输出结果。在该网络中,需要冻结判别器的权重,因为我们需要在此过程中训练生成器,让判别器的结果输出为“真”,从而不断完善生成器生成图像的水平,所以只需要训练生成器的层。

# 生成对抗网络使用生成器模型层和判别器模型层,它们共享权重。
adversarial_net = generator_net + discriminator_net

# 冻结判别器的层的权重
# trainable 属性只有编译后才生效,所以之前的判别器中同样的层还是可以训练的
for layer in discriminator_net:
    layer.trainable = False
adversarial = keras.models.Sequential(adversarial_net)

优化器:

optimizer = keras.optimizers.Adam(0.0002,0.5) # 优化器

模型编译:

adversarial.compile(loss=keras.losses.binary_crossentropy,optimizer=optimizer,metrics=['acc']) # 模型编译

2.5 数据准备与预处理

加载keras中内置的手写数据集

(image_set,_),_ = keras.datasets.mnist.load_data() # 加载数据集
image_set = image_set/127.5 - 1 # shape为(60000,28,28)
image_set = image_set.reshape((image_set.shape[0],28,28,1)) # shape为(60000,28,28,1)

准备训练过程中可视化的随机向量seed

num_example_to_generate = 6 # 用于绘图过程中生成图片的数量

seed = np.random.normal(0,1,(num_example_to_generate,LATENT_DTM)) # 生成6个长度为100的随机向量

用于记录训练过程中的准确率与损失

# 损失
g_loss_list = [] # 生成器
d_loss_list = [] # 判别器
# 准确率
g_acc_list = [] # 生成器
d_acc_list = [] # 判别器

2.6 主训练方法

def train(batch = 30000,batch_size = 300):
    # 准备batch_size大小的真假数据标签
    valid = np.ones((batch_size)) # 全是1
    fake = np.zeros((batch_size)) # 全是0
    
    # 使用进度条tqdm库
    batch_tqdm = tqdm.trange(batch)
    for index in batch_tqdm:
        
        # 随机选择batch_size数量的数据作为训练数据
        idx = np.random.randint(0,image_set.shape[0],batch_size)
        imgs = image_set[idx] 
        
        # 生成噪声数据并作为生成器的输入
        noise = np.random.normal(0,1,(batch_size,LATENT_DTM))
        
        # 使用生成器生成图像
        gen_imgs = generator.predict(noise)
        
        # 训练判别器
        # 使用真实图像和生成图像训练判别器,真实图像的标签全部为1,生成图像的标签全部为0
        d_state_real = discriminator.train_on_batch(imgs,valid) # 返回的是loss和acc
        d_state_fake = discriminator.train_on_batch(gen_imgs,fake)
        # 判别器在生成图像与真实图像两者的结果取平局值
        d_state = 0.5*(np.add(d_state_real,d_state_fake))
        
        # 训练判别器
        noise = np.random.normal(0,1,(batch_size,LATENT_DTM))
        # 训练生成对抗网络,目标是生成判别器人物真实的图像,因此标签为1
        # 因为生成对抗网络中的判别器的层都冻结了,所以实际上在训练生成器,不断生成更加逼真的图像
        adv_state = adversarial.train_on_batch(noise,valid)
        
        # 更新进度条后缀文本,用于输出训练进度
        state = f"[D loss:{d_state[0]:.4f} acc: {d_state[1]:.4f}]" \
                f"[G loss:{adv_state[0]:.4f} acc: {adv_state[1]:.4f}]"
        batch_tqdm.set_postfix(state=state)
        # 存储损失值和准确率
        g_loss_list.append(adv_state[0])
        g_acc_list.append(adv_state[1])
        d_loss_list.append(d_state[0])
        d_acc_list.append(d_state[1])
        
        if index%500 == 0: # 每500次绘图一次
            generate_plot_image(seed) # 绘图函数,每次都用同一个随机噪声seed生成图片,可以看到数字的变化

注意model的train_on_batch方法的使用。

2.7 绘图函数

用固定的noise绘制6张图片,以便观察训练效果。

# 画图函数
def generate_plot_image(test_noise):

    pre_image = generator(test_noise,training = False) # 用生成器,生成手写图片
    # print(pre_image.shape) # (6,28,28,1)
    fig = plt.figure(figsize=(16,3)) # figsize:指定figure的宽和高,单位为英寸
    for i in range(pre_image.shape[0]):   # pre_image的shape的第一个维度就是个数,这里是6
        plt.subplot(1,6,i+1) # 几行几列的 第i+1个图片(从1开始)
        plt.imshow((pre_image[i,:,:,:] + 1)/2) # 加1除2: 将生成的-1~1的图片弄到0-1之间,
        plt.axis('off') # 不要坐标
    plt.show()

2.8 开始训练

训练30000个batch,每个batch随机拿出300个图片用于训练。

batch = 30000
batch_size = 300
train(batch,batch_size)

2.9 loss与acc绘图

损失Loss:

plt.plot(range(1, batch+1), g_loss_list, label='g_loss')
plt.plot(range(1, batch+1), d_loss_list, label='d_loss')
plt.legend()

准确率Acc:

plt.plot(range(1, batch+1), g_acc_list, label='g_acc')
plt.plot(range(1, batch+1), d_acc_list, label='d_acc')
plt.legend()

2.10 结果

可以看到生成器生成图片的效果越来越好

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

loss:

在这里插入图片描述

acc:

更新GAN的另一种实现方法:使用TensorFlow2中求导机制进行自定义训练的GAN代码实现,可对比进行学习。
博客链接:【生成对抗网络】GAN入门与代码实现(二)

参考文献:《TensorFlow2实战》艾力

生成对抗网络GAN)是深度学习领域的一项重要技术,利用GAN可以有效地生成复杂的样本数据,例如图像、音频等。在本文中将介绍如何用pytorch搭建GAN,并对其进行详细的解释。 GAN网络由生成器和判别器两部分组成。生成器接受随机噪声作为输入,通过反向传递训练来生成逼真的样本,而判别器则负责对输入样本进行判断,判断其是否是真实样本。两部分交替训练,并不断优化生成器和判别器的参数,最终可以得到生成器生成逼真样本的能力。 搭建GAN需要先定义生成器和判别器的网络结构,其中生成器可以使用反卷积,而判别器可以使用卷积神经网络。此外,在搭建过程中还需要定义一些超参数,如学习率、训练轮数等。 在开始训练GAN之前,需要先准备好数据集,并对其进行预处理,例如归一化、降噪等。然后对生成器和判别器设置优化器,并开始训练。在训练过程中需要注意调整超参数以达到更好的效果。 最后,在训练结束后需要对GAN进行评估,可以通过计算生成样本与真实样本之间的差别来确定生成器的性能并对其进行改进。 总之,利用pytorch搭建入门GAN需要先定义网络结构和超参数,并使用适当的优化器进行训练,最终可以生成逼真的样本。同时,需要注意调整超参数以达到更好的效果,并对GAN进行评估和改进。
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值