生成型对抗性网络介绍与实现原理

如何无中生有是AI领域研究的重点。原有神经网络大多是对已有问题的识别和研究,例如让神经网络学会识别图片中的动物是猫还是狗,随着研究的进一步深入,目前能够做到让网络不但能识别图片中的物体,还能让它学会如何创造图片中的物体,具备”创造性“让AI技术的应用价值大大提升。

在深度学习“创造性”上做出巨大贡献的是来自谷歌大脑项目组的研究员Goodfellow提出一篇名为”Generative Adversarial Networks”的学术报告,他提出一种特别的网络结构,网络由两部分组成,这两部分形成一种对抗关系,一部分叫Generator,由它进行”创造“,例如给出一幅画像或某种数据,另一部分叫discrimator,它的任务是对前者的”创造“进行检验,前者的任务是创造出质量足够好的输出以便通过后者的检验,后者的任务是不断提高自己的检验能力以便识别出给定的数据是真是假。

我们通过一个形象的比喻来理解该网络。设想一个艺术家他擅长伪造毕加索的画,一个是鉴定家,他能识别出真迹和假迹,艺术家将自己伪造的画作交给鉴定家识别,如果被识别出,他就总结经验改进伪造能力,他不断的提升伪造能力,直到鉴定家鉴定不出为止,这样的话艺术家的作画能力就与毕加索没有任何区别。

我们先通过代码实践获得初步感受,以便为理论把握打下良好基础,我们先构造一个能学会如何”作画“的神经网络。我们使用谷歌提供的数据集做网络的训练数据,数据的下载链接为:
https://storage.cloud.google.com/quickdraw_dataset/full/numpy_bitmap/camel.npy
首先我们要把数据加载到内存中,代码如下:

import os
from os import walk
import numpy as np
def  load_data():
    path = "./data"
    txt_name_list = []
    for (dirpath, dirnames, filenames) in walk(path):
        for f in filenames:
            if f != '.DS_Store':
                txt_name_list.append(f)
                break
    slice_train = int(80000/len(txt_name_list))
    i = 0
    seed = np.random.randint(1, 10e6)

    for txt_name in txt_name_list:
        txt_path = os.path.join(path, txt_name)
        x = np.load(txt_path)
        x = (x.astype('float32') - 127.5) / 127.5
        x = x.reshape(x.shape[0], 28, 28, 1)

        y = [i] * len(x)
        np.random.seed(seed)
        np.random.shuffle(x)
        np.random.seed(seed)
        np.random.shuffle(y)
        x = x[:slice_train]
        y = y[:slice_train]
        if i != 0:
            xtotal = np.concatenate((x, xtotal), axis = 0)
            ytotal = np.concatenate((y, ytotal), axis = 0)
        else:
            xtotal = x
            ytotal = y
        i += 1
    return xtotal, ytotal
(x_train, y_train) = load_data()
import matplotlib.pyplot as plt
plt.imshow(x_train[200, :, :, 0], cmap = 'gray')

上面代码运行后所得结果如下:

这个数据集来自于谷歌的Quick,Draw!它是一笔手工画,我们的任务是训练网络,让它能生成类似风格的一笔手工画。接下来我们要构造两个网络,这两个网络性质上属于一阴一阳是一种相互对抗的关系。其中一个网络叫生成者,另一个网络叫做鉴别者,生成者网络的任务是生成尽可能类似上面图像的图片,鉴别者的任务是学会识别上面图像的特点,然后识别输入给它的图片到底是真实图片还是有生成者构造的图片,因此两个网络是互为博弈的关系。一方面算法会训练鉴别者对图片的识别能力越来越强,同时训练生成者生成的图片尽可能通过鉴别者的认定,随着鉴别者识别能力越强,生成者根据鉴别者的反馈调整内部参数,直到它生成的图片不断通过鉴别者识别后,他绘制出来的图片就越来越像真实图片。下图给出了鉴别者与识别者组成的对抗性生成型网络的结构图:

接下来我们看看生成者网络的实现代码:

class Model(tf.keras.Model):
    def  __init__(self):
        super(Model, self).__init__()
        self.layers = []
        self.weight_init = tf.keras.initializers.RandomNormal(mean = 0., stddev = 0.2) #用于初始化网络层参数
    def  get_activation(self, activation): #选定网络层的激活函数
        if  activation == 'leaky_relu':
            layer =  tf.keras.layers.LeakyReLU(alpha = 0.2)()
        else:
            layer = tf.keras.layers.Activation(activation)()
        return layer
     def  call(self, x):
        for layer in self.generator_layers:
            x = layer(x)
        return x

class  Generator(Model):
    def  __init__(self, generator_params):
        super(Generator, self).__init__()
        self.generator_layers = []
        self.weight_init = tf.keras.initializers.RandomNormal
        self.generator_layers.append(tf.keras.layers.Dense(units = generator_params.generator_initial_dense_layer_size,
                                                           kernel_initializer = self.weight_init))
        if  generator_params.generator_batch_norm_momentun:
            self.generator_layers.append(
                tf.keras.layers.BatchNormalization(momentum = self.generator_batch_nrom_momentun)
            )
        self.generator_layers.append(self.get_activation(generator_params.generator_activation))

        self.generator_layers.append(tf.keras.layers.Reshape(generator_params.generator_initial_dense_layer_size))

        if generator_params.generator_dropout_rate:
            self.generator_layers.append(tf.keras.layers.Dropout(generator_params.generator_dropout_rate))
        for i in range(generator_params.n_layers_generator):
            if  generator_params.generator_upsample[i] == 2:
                '''
                UpSampling2D会将像素点在前后左右进行复制,例如:
                Input = [1,2
                         3,4]  经过计算后得:
                output = [1, 1, 2, 2
                          1, 1, 2, 2
                          3, 3, 4, 4,
                          3, 3, 4, 4]
                Conv2DTranspose同样会把输入扩展为原来的2倍,只不过新增加的像素点并不是直接复制而是
                通过训练后网络寻找出最合适的像素点值
                '''
                self.generator_layers.append(tf.keras.UpSampleing2D())
                self.generator_layers.append(
                    tf.keras.layers.Conv2D(filters = generator_params.generator_conv_filters[i],
                                           kernel_size = generator_params.generator_conv_kernel_size[i],
                                           padding = 'same',
                                           name = 'generator_conv_' + str(i),
                                           kernel_initializer = self.weight_init
                                           )
                )
            else:
                self.generator_layers.append(
                    tf.keras.layers.Conv2DTranspose(
                        filters = generator_params.generator_conv_filters[i],
                        kernel_size = generator_params.generator_conv_kernel_size[i],
                        padding = 'same',
                        strides = generator_params.generator_conv_strides[i],
                        name = "generator_conv_" + str(i),
                        kernel_initializer = self.weight_init
                    )
                )
            if  i < generator_params.n_layer_generator - 1:
                if  generator_batch_norm_momentum:
                    self.generator_layers.append(
                        tf.keras.layers.BatchNormalization(momentum = generator_params.generator_batch_norm_momentum)
                    )
                self.generator_layers.append(
                    self.get_activation(generator_params.generator_activation)
                )
            else: #最后生成图像的像素点值在[-1,1]之间后面会进一步把像素点值改为[0,1]之间
                self.generator_layers.append(
                    tf.keras.Activation('tanh')
                )

        self.layers = self.generator_layers

生成者网络会接收含有给定分量的高维关键向量,然后把向量转换为对应图片,后面我们会在标准正太分布中获取向量,这意味着生成者网络的任务其实是学会将标准正太分布中的每个点映射成给定风格的图片。接下来我们看看鉴别者网络的代码实现:

class Discriminator(Model):
    def __init__(self, params):
        super(Discriminator, self).__init__()
        self.dsicriminator_layers = []
        self.weight_init = tf.keras.initializers.RandomNormal
        for i in range(params.n_layers_discriminator):
            self.dsicriminator_layers.append(
                tf.keras.layers.Conv2D(
                    filters = params.discriminator_conv_filters[i],
                    kernel_size = params.discriminator_conv_kernel_size[i],
                    strides = params.discriminator_conv_strides[i],
                    padding = 'same',
                    name = 'discriminator_conv_' + str(i),
                    kernel_initializer = self.weight_init
                )
            )
            if  params.discriminator_batch_norm_momentum and i > 0:
                self.discriminator_layers.append(
                    tf.keras.layers.BatchNormalization(momentum = params.discriminator_batch_norm_momentun)
                )
            self.discriminator_layers.append(self.get_activation(params.discrimator_activation))
            if  params.discriminator_dropout_rate:
                self.discriminator_layers.append(
                    tf.keras.layers.Dropout(rate = params.discriminator_dropout_rate)
                )
        self.discriminator_layers.append(
            tf.keras.layers.Flatten()
        )
        self.discriminator_layers.append(
            tf.keras.layers.Dense(units = 2, activation = 'softmax',
                                  kernel_initializer = self.weight_init)#计算输入数据为真或假的概率
        )
        self.discriminator_layers.append(
            tf.argmax
        )

        self.layers = self.discriminator_layers

鉴别者网络接收图片对应二维数组,然后给出0或1用于表示输入图片是否为真实图片,下面我们把两个网络连接起来进行训练:

class GAN():
     def  __init__(self, discrimiator_params, generator_params, z_dim):
         self.d_losses = []
         self.g_losses = []
         self.epoch = 0
         self.z_dim = z_dim  #关键向量的维度
         #设置生成者和鉴别者网络的优化函数
         self.discriminator_optimizer = tf.train.AdamOptimizer(discriminator_params.learning_rate)
         self.generator_optimizer = tf.train.AdamOptimizer(generator_params.learning_rate)
         generator_params.n_layers_generator = len(generator_params.generator_conv_filters)
         self.generator = Generator(generator_params)
         discriminator_params.n_layers_discriminator = len(discriminator_params.discriminator_conv_filters)
         self.discriminator = Discriminator(discriminator_params)
         self.build_adversarial()

    def  train_discriminator(self, x_train, batch_size, using_generator):
        '''
        训练鉴别师网络,它的训练分两步骤,首先是输入正确图片,让网络有识别正确图片的能力。
        然后使用生成者网络构造图片,并告知鉴别师网络图片为假,让网络具有识别生成者网络伪造图片的能力
        '''
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        if using_generator:#使用数据加载器
            true_imgs = next(x_train)[0] #读入图片数据
            if true_imgs.shape[0] != batch_size:
                true_imgs = next(x_train)[0]
        else:#之间从文件系统读取训练数据
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            true_imgs = x_train[idx]
        noise = np.random.normal(0, 1, (batch_size, self.z_dim))
        gen_imgs = self.generator(noise) #让生成者网络根据关键向量生成图片
        real_accuracy = tfe.metrics.Accuracy()  #计算网络识别正确图片的成功率
        fake_accuracy = tfe.metrics.Accuracy() #计算网络识别构造图片的成功率

        with tf.GradientTape(watch_accessed_variables=False) as tape: #只修改鉴别者网络的内部参数
            tape.watch(self.discriminator.trainable_variables)
            d_loss_real = tf.keras.losses.BinaryCrossentropy(y_true = valid, self.discriminator(true_imgs))
        grads = tape.gradients(d_loss_real, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(zip(grads, self.discriminator.trainbale_variables)) #改进鉴别者网络内部参数
        d_acc_real = real_accuracy(labels = valid, predictions = self.discriminator(true_imgs)).result().numpy() #计算训练后网络识别真图片的正确率

        with tf.GradientTape(watch_accessed_variables=False) as tape: #只修改鉴别者网络的内部参数
            tape.watch(self.discriminator.trainable_variables)
            d_loss_fake = tf.keras.losses.BinaryCrossentropy(y_true = fake,
                                                             y_pred = self.discriminator(gen_imgs))
        grads = tape.gradients(d_loss_fake, self.discriminator.trainbale_variables)
        self.discriminator_optimizer.apply_gradients(ziep(grads, self.discriminator.trainable_variables))
        d_acc_fake = fake_accuracy(labels = fake, predictions = self.discriminator(gen_imgs)).result().numpy() #计算鉴别者网络识别虚假图片的正确率

        d_loss = 0.5*(d_loss_real + d_loss_fake)
        d_acc = 0.5 * (d_acc_fake + d_acc_real)

        return [d_loss, d_loss_real, d_loss_fake, d_acc, d_acc_real, d_acc_fake]

    def  train_generator(self, batch_size): #训练生成者网络
        '''
        生成者网络训练的目的是让它生成的图像尽可能通过鉴别者网络的审查
        '''
        valid = np.ones((batch_size, 1)) #希望生成的图片尽可能多通过鉴别者网络的鉴定
        noise = np.random.normal(0, 1, (batch_size, self.z_dim)) #随机生成关键向量
        with tf.GradientTape(watch_accessed_variables=False) as tape: #只能修改生成者网络的内部参数不能修改鉴别者网络的内部参数
            tape.watch(self.generator.trainbale_variables)
            gen_imgs = self.generator(noise) #生成伪造的图片
            verify_results = self.discriminator(gen_imgs)
            verify_loss = tf.keras.keras.BinaryCrossentropy(y_true = valid,
                                                            y_pred = verify_results)
        grads = tape.gradients(verify_loss, self.generator.trainable_variables) #调整生成者网络内部参数使得它生成的图片尽可能通过鉴别者网络的识别
        self.generator_optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))
        pass_accuracy = tfe.metrics.Accuracy() #计算生成者网络能通鉴定的成功率
        gen_imgs = self.generator(noise)
        verify_results = self.discriminator(gne_imgs)
        accuracy = pass_accuracy(labels = valid, predictions = verify_results).result().numpy()#检验训练后生成者网络的成功率
        return verify_loss, accuracy
    def  train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches = 50,
               using_generator = False):#启动训练流程
         for  epoch in range(self.epoch, self.epoch + epoches):
             d = self.train_discriminator(x_train, batch_size, using_generator)
             g = self.train_generator(batch_size)
             print("%d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f] [G acc: %.3f]" %
                   (epoch, d[0], d[1], d[2], d[3], d[4], g[0], g[1]))
             if epoch % print_every_n_batches == 0:
                 self.sample_images(run_folder) #将生成者构造的图像绘制出来
                 self.save_model(run_folder) #存储两个网络的内部参数
            self.epoch + 1
    def  sample_images(self, run_folder): #绘制生成者构建的图像
        r, c = 5,5
        noise = np.random.normal(0, 1, (r * c, self.z_dim)) #构建关键向量
        gen_imgs = self.generator(noise)
        gen_imgs = 0.5 * (gen_imgs + 1)
        gen_imgs = np.clip(gen_imgs, 0, 1) #将图片像素点转换到[0,1]之间
        fig, axs = plt.subplots(r, c, figsize = [15, 15])
        cnt = 0

        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]), cmap = 'gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(run_folder, 'images/sample%d.png' % self.epoch))
        plt.close()
    def  save_model(self, run_folder): #保持网络内部参数
        self.discriminator.save_weights(os.path.join(run_folder, 'discriminator.h5'))
        self.generator.save_weights(os.path.join(run_folder, 'generator.h5'))
    def  load_model(self, run_folder):
        self.discriminator.load_weights(os.path.join(run_folder, 'discriminator.h5'))
        self.generator.load_weights(os.path.join(run_folder, 'generator.h5'))

从代码上看,两个网络的训练流程不同。鉴别者网络需要输入两种图片,一种是真实图片,它要学会识别图片特征把握真实图片的特性,另一种是生成者网络构造的图片,它要训练内部参数,使得将生成者网络构造的图片识别为假。生成者网络的训练比较特别,它必须结合鉴别者网络才能进行训练,首先它接收来自正太分布的关键向量,然后输出一幅图像,接着把图像输入到鉴别者网络,根据后者给出的结果调整内部参数,它调整参数的结果是尽可能让鉴别者网络输出结果为真,要注意的是在训练生成者网络时,一定有要冻结鉴别者网络的内部参数,如果在训练生成者时也改变了鉴别者的内部参数,那么整个训练将无法收敛。

由于对抗性生成型网络的训练过程较为复杂,有必要单独着重讲解,因此本节只给出网络的实现代码,下一节我们看看如何有效的训练两个网络。

更多精彩内容和详细讲解请点击’阅读原文‘

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值