[TensorFlow]生成对抗网络(GAN)介绍与实践

主旨

本文简要介绍了生成对抗网络(GAN)的原理,接下来通过tensorflow开发程序实现生成对抗网络(GAN),并且通过实现的GAN完成对等差数列的生成和识别。通过对设计思路和实现方案的介绍,本文可以辅助读者理解GAN的工作原理,并掌握实现方法。有了这样的基础,在面对工作中实际问题时可以将GAN纳入考虑,选择最合适的算法

代码和运行环境

代码位置: 
wangyaobupt/GAN

TensorFlow版本

>>> tf.version 
‘1.1.0-rc2’

背景知识

Generative Adversarial Nets[1][https://arxiv.org/pdf/1406.2661v1.pdf]是Ian J. Goodfellow等在2014年提出的一种训练模型的方法,此方法通过两个网络(生成网络G和分类网络D)对抗训练,得到符合预期目标的生成模型和分类模型。

要理解GAN的原理,上述论文是最好的教材。但考虑到原文首先是英文撰写,其次包含不少数学推导,新手上手并不容易。因此笔者这里班门弄斧,基于论文简单转述GAN的设计思想要点

GAN的目标,给定一个真实样本(本文也称之为ground truth)集合,训练出两个模型,一个能够从噪声信号生成尽可能像ground truth的样本;另一个能够判断给定样本是否是ground truth。两个模型详细介绍如下

  • 生成模型:论文中称为generative model,本文称为G网络或G模型。G网络的输入是噪声信号(例如均匀分布的随机数),输出为形状与真实样本ground truth一致。G网络的训练目标是,尽可能输出与ground truth相似的样本。这里“相似”定义为:如果G网络生成的一个样本骗过了D网络,使得D网络误以为这就是真实样本,则就是相似的,G网络获得奖励;反之,获得惩罚。
  • 分类模型:论文中称为discriminative model,本文称为D网络或D模型。D网络是一个2分类器,输入为ground truth或者G网络生成的样本,输出为TRUE或FALSE:TRUE表示D网络认为当前输入样本是ground truth,FALSE表示D网络认为当前输入样本是G网络生成的“伪造”样本。D网络的训练目标是尽可能正确的区分开ground truth和G网络生成的“伪造”样本。

从上述讨论可以看出,G网络和D网络是两个目标完全相反的网络,G网络尽其所能“伪造”出像真实样本的数据,D网络尽可能区分真实与伪造数据。GAN中所谓“对抗”的概念,即来源于此。

GAN的训练过程就是G和D两个网络互相对抗的过程,对抗的结果是G网络被训练到能够生成以假乱真的样本,即G网络从噪声输入得到了尽可能与真实样本相似的输出,或者说G学会了从噪声生成ground truth的方法;D网络也可以区分ground truth与其他样本,即D学会了区分ground truth与其他数据的方法。

参考文献 
1. Goodfellow I J, Pougetabadie J, Mirza M, et al. Generative adversarial nets[C]. neural information processing systems, 2014: 2672-2680.

神经网络设计和实现

问题构造

在开始设计神经网络之前,我们首先构造出预期GAN解决的问题。前述GAN论文中提出了一个从噪声学习正态分布的经典问题,读者如果在网络上搜索GAN的案例,除了图像识别,基本上只有这么一个问题和方案实现。

本文重新设计了一个与论文中不同的问题。问题描述如下

  • Ground Truth定义:[1,2,3,4,5,6,7,8,9,10]构成的等差数列,为了适当降低学习难度,此数列每个元素与噪声相加,噪声为0均值正态分布随机变量,标准差取0.1, 0.03, 0等不同数值
  • 输入噪声定义: [-1,1]之间均匀分布的随机变量。

网络结构设计

G网络:参考论文资料,我们选择多层全连接神经网络

D网络:由于要分辨的是等差数列,我们选择RNN作为D网络。

网络结构如下(下图是tensorboard生成的计算图):图中”G_net”表示G网络,”D_net”/”D_net_1”表示D网络,虽然图中D网络被分成了两份,但是其RNN参数是共享的,即图中正下方”rnn”这个单元。


代码实现

G网络定义

   # generative network
    # use multi-layer percepton to generate time sequence from random noise
    # input tensor must be in shape of (batch_size, self.seq_len)
    def generator(self, inputTensor):
        with tf.name_scope('G_net'):
            gInputTensor = tf.identity(inputTensor, name='input')
            # Multilayer percepton implementation
            numNodesInEachLayer = 10
            numLayers = 3 

            previous_output_tensor = gInputTensor
            for layerIdx in range(numLayers):
                activation,z = self.fullConnectedLayer(previous_output_tensor, numNodesInEachLayer, layerIdx)
                previous_output_tensor = activation

            g_logit = z
            g_logit = tf.identity(g_logit, 'g_logit')
            return g_logit

G网络损失函数 
下面代码片段中self.d_logit_fake是D网络对G网络生成数据的判定结果。由于G网络的目标是尽可能骗过D网路,如果D网络对于G网络生成数据全部判为1(即TRUE),则损失最小,反之,损失最大。

g_loss_d = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.d_logit_fake,
                    labels=tf.ones(shape=[self.batch_size_t,1])
                    ),
                name='g_loss_d'
                )

D网络的定义 
RNN+全连接输出层,无论是RNN还是全连接层都必须在对ground truth和G生成样本之间共享同一套参数

def discriminator(self, inputTensor,reuseParam):
        with tf.name_scope('D_net'):
            num_units_in_LSTMCell = 10

            # RNN definition
            with tf.variable_scope('d_rnn'):
                lstmCell = tf.contrib.rnn.BasicLSTMCell(num_units_in_LSTMCell,reuse=reuseParam)
                init_state = lstmCell.zero_state(self.batch_size_t, dtype=tf.float32)
                raw_output, final_state = tf.nn.dynamic_rnn(lstmCell, inputTensor, initial_state=init_state)

            rnn_output_list = tf.unstack(tf.transpose(raw_output, [1, 0, 2]), name='outList')
            rnn_output_tensor = rnn_output_list[-1];

            # Full connected network
            numberOfInputDims = inputTensor.shape[1].value
            numOfNodesInLayer = 1
            if not reuseParam:
                self.d_w = tf.Variable(initial_value=tf.random_normal([numberOfInputDims, numOfNodesInLayer]),
                        name=('dnet_w_1'))
                self.d_b = tf.Variable(tf.zeros([1, numOfNodesInLayer]), name='dnet_b_1')
            self.d_z = tf.matmul(rnn_output_tensor,self.d_w) + self.d_b
            self.d_z = tf.identity(self.d_z, name='dnet_z_1')
            d_sigmoid = tf.nn.sigmoid(self.d_z, name='dnet_a_1')

            d_logit = self.d_z
            d_logit = tf.identity(d_logit, 'd_net_logit')
            return d_logit

D网络损失函数 
D网络使用同一套参数分辨两种输入,一种是ground truth,另一种是G网络的输出。对于ground truth,训练目标为尽可能判为1,对于G网络的输出,训练目标为尽可能判为0,因此Loss函数定义如下

# For D-network, jduge ground truth to TRUE, jduge G-network output to FALSE,making loss low
            d_loss_ground_truth = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.d_logit_gnd_truth,
                    labels=tf.ones(shape=[self.batch_size_t,1])
                    ),
                name='d_loss_gnd'
                )

            d_loss_fake = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.d_logit_fake,
                    labels=tf.zeros(shape=[self.batch_size_t,1])
                    ),
                name='d_loss_fake'
                )

            d_loss = d_loss_ground_truth + d_loss_fake

对抗训练 
对抗训练中,G网络Loss值只用来调整G网络参数,D网络Loss值只用来调整D网络参数

       g_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G_net')
        g_net_var_list = g_net_var_list +  tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='g_rnn')
        self.train_g = tf.train.AdamOptimizer(self.lr_g).minimize(g_loss,var_list=g_net_var_list)

        d_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='D_net')
        d_net_var_list = d_net_var_list +  tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='d_rnn')
        self.train_d = tf.train.AdamOptimizer(self.lr_d).minimize(d_loss,var_list=d_net_var_list)

训练效果

下图是训练过程中D网络对ground truth和G网络输出的分类正确率曲线 

从图中可以看到3个阶段

  1. 训练开始后一秒左右:“D网络对ground truth的分类正确率”和“D网络对G网络输出的分类正确率”都快速上升到100%,即D网络经过训练可以完全正确的将真的判为真,假的判为假
  2. 训练后1-15s:D网络分类正确率保持全对
  3. 15s之后:“D网络对ground truth的分类正确率”和“D网络对G网络输出的分类正确率”出现震荡,表明在这个阶段G网络已经能够以假乱真,D网络将部分G网络输出判为真,同时也将部分ground truth判为假。

上述3个阶段就体现出对抗训练的特点,G网络和D网络互为对手,互相提高对方的训练难度,最终得到符合预期的模型。

接下来再从数据上给一个直观的认识

Ground truth: 在公差为1的等差数列上加入stddev=0.3, mean=0的正态分布噪声后,得到的一组Ground Truth数据如下

[ 1.1539436 ] 
[ 2.08863655] 
[ 2.78491645] 
[ 3.93027817] 
[ 4.75851967] 
[ 5.88655699] 
[ 7.10540526] 
[ 7.43159023] 
[ 9.19373617] 
[ 10.08779359]

训练开始前G网络的数据

基本无规律,和输入噪声分布接近

[ 1.15080559] 
[ 0.66351247] 
[-0.39484465] 
[-0.41690648] 
[ 0.29061955] 
[ 0.06131642] 
[-2.46439648] 
[-1.53692639] 
[-0.30550677] 
[-0.89200932]

迭代100次之后G网络的输出 
出现等差数列的端倪

[ -0.53692651] 
[ 0.86063552] 
[ 2.47294378] 
[ 5.24512053] 
[ 7.7618413 ] 
[ 9.57867622] 
[ 9.15039253] 
[ 9.86567402] 
[ 10.62975025] 
[ 10.24322414]

迭代500次之后G网络的输出 
除了最后一个元素,前9个元素已经基本符合预期

[ 1.09549832] 
[ 2.21490908] 
[ 2.95311546] 
[ 4.06684017] 
[ 4.96308947] 
[ 6.03393888] 
[ 6.89026165] 
[ 7.93375683] 
[ 8.63552094] 
[ 9.07077026]

迭代1500次之后G网络的输出 
已经足以以假乱真

[ 0.07186054] 
[ 1.08289695] 
[ 2.55904818] 
[ 4.07374573] 
[ 5.14763832] 
[ 6.07010031] 
[ 6.79585028] 
[ 8.17086124] 
[ 8.81297684] 
[ 10.38190079]

更多资料

本文首发于:王尧的技术博客。博客中的内容体系性不如在知乎整理的清楚,但会随时记录工作中的技术问题和发现,如有兴趣欢迎围观。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值