用PaddlePaddle(飞桨)实现minist数据集的GAN生成

任务描述

GAN全称是 Generative Adversarial Network,即生成对抗网络。在14年被Goodfellow等提出后即热度不断一经推出便引爆全场,此后各种花式变体DCGAN、WGAN、CGAN、CYCLEGAN、STARGAN、LSGAN等层出不穷,在“换脸”、“换衣”、“换天地”等应用场景下生成的图像、视频以假乱真,好不热闹。

生成对抗网络一般由一个生成器(生成网络),和一个判别器(判别网络)组成。

生成器的作用是,通过学习训练集数据的特征,在判别器的指导下,将随机噪声分布尽量拟合为训练数据的真实分布,从而生成具有训练集特征的相似数据。而判别器则负责区分输入的数据是真实的还是生成器生成的假数据,并反馈给生成器。两个网络交替训练,能力同步提高,直到生成网络生成的数据能够以假乱真,并与与判别网络的能力达到一定均衡。

数据准备

训练集数据使用飞桨框架内置函数paddle.dataset.mnist.train()、paddle.reader.shuffle()和paddle.batch()进行读取、打乱和划分batch。读取图片数据处理为 [N,W,H] 格式。

要喂入生成器高斯分布的噪声隐变量z的维度设置为100。

import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, Pool2D, Linear
import numpy as np
import matplotlib.pyplot as plt

# 噪声维度
Z_DIM = 100
BATCH_SIZE = 128
# 读取真实图片的数据集,这里去除了数据集中的label数据,因为label在这里使用不上,这里不考虑标签分类问题。
def mnist_reader(reader):
    def r():
        for img, label in reader():
            yield img.reshape(1, 28, 28)
    return r

# 噪声生成,通过由噪声来生成假的图片数据输入。
def z_reader():
    while True:
        yield np.random.normal(0.0, 1.0, (Z_DIM, 1, 1)).astype('float32')                #正态分布,正态分布的均值、标准差、参数

# 生成真实图片reader
mnist_generator = paddle.batch(
        paddle.reader.shuffle(mnist_reader(paddle.dataset.mnist.train()), 30000),
        batch_size=BATCH_SIZE)

# 生成假图片的reader
z_generator = paddle.batch(z_reader, batch_size=BATCH_SIZE)

测试下数据读取器和高斯噪声生成器。

import matplotlib.pyplot as plt
%matplotlib inline

pics_tmp = next(mnist_generator())
print('一个batch图片数据的形状:batch_size =', len(pics_tmp), ', data_shape =', pics_tmp[0].shape)

plt.imshow(pics_tmp[0][0]) # (28,28)
plt.show

输出: 一个batch图片数据的形状:batch_size = 128 , data_shape = (1, 28, 28) <function matplotlib.pyplot.show(*args, **kw)>

 

z_tmp = next(z_generator())
print('一个batch噪声z的形状:batch_size =', len(z_tmp), ', data_shape =', z_tmp[0].shape)
plt.imshow(z_tmp[0][0]) # (28,28)
plt.show

输出:一个batch噪声z的形状:batch_size = 128 , data_shape = (100, 1, 1)

GAN网络

GAN性能的提升从生成器G和判别器D进行左右互搏、交替完善的过程得到的。所以其G网络和D网络的能力应该设计得相近,复杂度也差不多。这个项目中的生成器,采用了两个全链接层接两组上采样和转置卷积层,将输入的噪声Z逐渐转化为1×28×28的单通道图片输出。

生成器结构:

判别器的结构正好相反,先通过两组卷积和池化层将输入的图片转化为越来越小的特征图,再经过两层全链接层,输出图片是真是假的二分类结果。

判别器结构:

 

# 通过上采样扩大特征图
class G(fluid.dygraph.Layer):
    def __init__(self, name_scope):
        super(G, self).__init__(name_scope)
        name_scope = self.full_name()
        #
        # My_G的代码
        #
        self.fc1=Linear(100,1024)
        self.bn1=fluid.dygraph.BatchNorm(num_channels=1024,act='leaky_relu')
        self.fc2=Linear(input_dim=1024,output_dim=128*7*7)
        self.bn2=fluid.dygraph.BatchNorm(num_channels=128*7*7,act='leaky_relu')
        self.conv1=Conv2D(num_channels=128,num_filters=64,filter_size=3,stride=1,padding=1)
        self.bn3=fluid.dygraph.BatchNorm(num_channels=64,act='leaky_relu')
        self.conv2=Conv2D(num_channels=64,num_filters=1,filter_size=3,stride=1,padding=1)
        self.bn4=fluid.dygraph.BatchNorm(num_channels=1,act='tanh')
        
 def forward(self, z):
        #
        # My_G forward的代码
        z=fluid.layers.reshape(z,shape=[-1,100])
        y=self.fc1(z)
        y=self.bn1(y)
        y=self.fc2(y)
        y=self.bn2(y)
        y=fluid.layers.reshape(y,shape=[-1,128,7,7])
        y=fluid.layers.image_resize(y,scale=2)
        y=self.conv1(y)
        y=self.bn3(y)
        y=fluid.layers.image_resize(y,scale=2)
        y=self.conv2(y)
        y=self.bn4(y)
        return y

class D(fluid.dygraph.Layer):
    def __init__(self, name_scope):
        super(D, self).__init__(name_scope)
        name_scope = self.full_name()
        #
        # My_D的代码
        self.conv1=Conv2D(num_channels=1,num_filters=64,filter_size=3)
        self.bn1=fluid.dygraph.BatchNorm(num_channels=64,act='leaky_relu')
        self.pool1=Pool2D(pool_size=2,pool_stride=2)
        self.conv2=Conv2D(num_channels=64,num_filters=128,filter_size=3)
        self.bn2=fluid.dygraph.BatchNorm(num_channels=128,act='leaky_relu')
        self.pool2=Pool2D(pool_size=2,pool_stride=2)
        self.fc1=Linear(input_dim=128*5*5,output_dim=1024)
        self.bnfc1=fluid.dygraph.BatchNorm(num_channels=1024,act='leaky_relu')
        self.fc2=Linear(input_dim=1024,output_dim=1)
    def forward(self, img):
        #
        # My_G forward的代码
        y=self.conv1(img)
        y=self.bn1(y)
        y=self.pool1(y)
        y=self.conv2(y)
        y=self.bn2(y)
        y=self.pool2(y)
        y=fluid.layers.reshape(y,shape=[-1,128*5*5])
        y=self.fc1(y)
        y=self.bnfc1(y)
        y=self.fc2(y)
        return y

测试生成器G网络和判别器D网络的前向计算结果。一个batch的数据,输出一张图片。

# 测试生成网络G和判别网络D
with fluid.dygraph.guard():
    g_tmp = G('G')
    tmp_g = g_tmp(fluid.dygraph.to_variable(np.array(z_tmp))).numpy()
    print('生成器G生成图片数据的形状:', tmp_g.shape)
    plt.imshow(tmp_g[0][0])
    plt.show()
    
    d_tmp = D('D')
    tmp_d = d_tmp(fluid.dygraph.to_variable(tmp_g)).numpy()
    print('判别器D判别生成的图片的概率数据形状:', tmp_d.shape)
    print(max(tmp_d))
# 显示图片,构建一个16*n大小(n=batch_size/16)的图片阵列,把预测的图片打印到note中。
import matplotlib.pyplot as plt
%matplotlib inline

def show_image_grid(images, batch_size=128, pass_id=None):
    fig = plt.figure(figsize=(8, batch_size/32))
    fig.suptitle("Pass {}".format(pass_id))
    gs = plt.GridSpec(int(batch_size/16), 16)
    gs.update(wspace=0.05, hspace=0.05)

    for i, image in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(image[0], cmap='Greys_r')    
    plt.show()

show_image_grid(tmp_g, BATCH_SIZE)

网络训练

网络的训练优化目标就是如下公式:

公式出自Goodfellow在2014年发表的论文Generative Adversarial Nets。 这里简单介绍下公式的含义和如何应用到代码中。上式中等号左边的部分:

表示的是生成样本和真实样本的差异度,可以使用二分类(真、假两个类别)的交叉商损失。

表示在生成器固定的情况下,通过最大化交叉商损失来更新判别器D的参数。

表示生成器要在判别器最大化真、假图片交叉商损失的情况下,最小化这个交叉商损失。

等式的右边其实就是将等式左边的交叉商损失公式展开,并写成概率分布的期望形式。详细的推导请参见原论文《Generative Adversarial Nets》。

下面是训练模型的代码,有详细的注释。大致过程是:先用真图片训练一次判别器d的参数,再用生成器g生成的假图片训练一次判别器d的参数,最后用判别器d判断生成器g生成的假图片的概率值更新一次生成器g的参数,即每轮训练先训练两次判别器d,再训练一次生成器g,使得判别器d的能力始终稍稍高于生成器g一些。

def train(mnist_generator, epoch_num=1, batch_size=128, use_gpu=True, load_model=False):
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        # 模型存储路径
        model_path = './output/'
        d = D('D')
        d.train()
        g = G('G')
        g.train()
        # 创建优化方法
        real_d_optimizer = fluid.optimizer.AdamOptimizer(learning_rate=5e-4, parameter_list=d.parameters())
        fake_d_optimizer = fluid.optimizer.AdamOptimizer(learning_rate=5e-4, parameter_list=d.parameters())
        g_optimizer = fluid.optimizer.AdamOptimizer(learning_rate=5e-4, parameter_list=g.parameters())
        
        # 读取上次保存的模型
        if load_model == True:
            g_para, g_opt = fluid.load_dygraph(model_path+'g')
            d_para, d_r_opt = fluid.load_dygraph(model_path+'d_o_r')
            # 上面判别器的参数已经读取到d_para了,此处无需再次读取
             _, d_f_opt = fluid.load_dygraph(model_path+'d_o_f')
            g.load_dict(g_para)
            g_optimizer.set_dict(g_opt)
            d.load_dict(d_para)
            real_d_optimizer.set_dict(d_r_opt)
            fake_d_optimizer.set_dict(d_f_opt)

        iteration_num = 0
        for epoch in range(epoch_num):
            for i, real_image in enumerate(mnist_generator()):
                # 丢弃不满整个batch_size的数据
                if(len(real_image) != BATCH_SIZE):
                    continue               
                iteration_num += 1                
                '''
                判别器d通过最小化输入真实图片时判别器d的输出与真值标签ones的交叉熵损失,来优化判别器的参数,
                以增加判别器d识别真实图片real_image为真值标签ones的概率。
                '''
                # 将MNIST数据集里的图片读入real_image,将真值标签ones用数字1初始化
                real_image = fluid.dygraph.to_variable(np.array(real_image))
                ones = fluid.dygraph.to_variable(np.ones([len(real_image), 1]).astype('float32'))
                 # 计算判别器d判断真实图片的概率
                p_real = d(real_image)
                # 计算判别真图片为真的损失
                real_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_real, ones)
                real_avg_cost = fluid.layers.mean(real_cost)
                # 反向传播更新判别器d的参数
                real_avg_cost.backward()
                real_d_optimizer.minimize(real_avg_cost)
                d.clear_gradients()
                
                '''
                判别器d通过最小化输入生成器g生成的假图片g(z)时判别器的输出与假值标签zeros的交叉熵损失,
                来优化判别器d的参数,以增加判别器d识别生成器g生成的假图片g(z)为假值标签zeros的概率。
                '''
                # 创建高斯分布的噪声z,将假值标签zeros初始化为0
                z = next(z_generator())
                z = fluid.dygraph.to_variable(np.array(z))
                zeros = fluid.dygraph.to_variable(np.zeros([len(real_image), 1]).astype('float32'))
                # 判别器d判断生成器g生成的假图片的概率
                p_fake = d(g(z))
                # 计算判别生成器g生成的假图片为假的损失
                fake_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_fake, zeros)
                fake_avg_cost = fluid.layers.mean(fake_cost)
                # 反向传播更新判别器d的参数
                fake_avg_cost.backward()
                fake_d_optimizer.minimize(fake_avg_cost)
                d.clear_gradients()

                '''
                生成器g通过最小化判别器d判别生成器生成的假图片g(z)为真的概率d(fake)与真值标签ones的交叉熵损失,
                来优化生成器g的参数,以增加生成器g使判别器d判别其生成的假图片g(z)为真值标签ones的概率。
                '''
                # 生成器用输入的高斯噪声z生成假图片
                fake = g(z)
                # 计算判别器d判断生成器g生成的假图片的概率
                p_confused = d(fake)
                # 使用判别器d判断生成器g生成的假图片的概率与真值ones的交叉熵计算损失
                g_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_confused, ones)
                g_avg_cost = fluid.layers.mean(g_cost)
                # 反向传播更新生成器g的参数
                g_avg_cost.backward()
                g_optimizer.minimize(g_avg_cost)
                g.clear_gradients()
                
                # 打印输出
                if(iteration_num % 100 == 0):
                    print('epoch =', epoch, ', batch =', i, ', real_d_loss =', real_avg_cost.numpy(),
                     ', fake_d_loss =', fake_avg_cost.numpy(), 'g_loss =', g_avg_cost.numpy())
                    show_image_grid(fake.numpy(), BATCH_SIZE, epoch)                             
        
        # 存储模型
        fluid.save_dygraph(g.state_dict(), model_path+'g')
        fluid.save_dygraph(g_optimizer.state_dict(), model_path+'g')
        fluid.save_dygraph(d.state_dict(), model_path+'d_o_r')
        fluid.save_dygraph(real_d_optimizer.state_dict(), model_path+'d_o_r')
        fluid.save_dygraph(d.state_dict(), model_path+'d_o_f')
        fluid.save_dygraph(fake_d_optimizer.state_dict(), model_path+'d_o_f')

train(mnist_generator, epoch_num=100, batch_size=BATCH_SIZE, use_gpu=True) # 10

最终经过epoch=100轮时,gan输出图片如下

 

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小小谢先生

支持知识付费

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值