对抗思想与强化学习的碰撞-SeqGAN模型原理和代码解析

1、背景GAN作为生成模型的一种新型训练方法,通过discriminative model来指导generative model的训练,并在真实数据中取得了很好的效果。尽管如此,当目标是一个待生成的非连续性序列时,该方法就会表现出其局限性。非连续性序列生成,比如说文本生成,为什么单纯的使用GAN没有取得很好的效果呢?主要的屏障有两点:1)在GAN中,...
摘要由CSDN通过智能技术生成

1、背景

GAN作为生成模型的一种新型训练方法,通过discriminative model来指导generative model的训练,并在真实数据中取得了很好的效果。尽管如此,当目标是一个待生成的非连续性序列时,该方法就会表现出其局限性。非连续性序列生成,比如说文本生成,为什么单纯的使用GAN没有取得很好的效果呢?主要的屏障有两点:

1)在GAN中,Generator是通过随机抽样作为开始,然后根据模型的参数进行确定性的转化。通过generative model G的输出,discriminative model D计算的损失值,根据得到的损失梯度去指导generative model G做轻微改变,从而使G产生更加真实的数据。而在文本生成任务中,G通常使用的是LSTM,那么G传递给D的是一堆离散值序列,即每一个LSTM单元的输出经过softmax之后再取argmax或者基于概率采样得到一个具体的单词,那么这使得梯度下架很难处理。

2)GAN只能评估出整个生成序列的score/loss,不能够细化到去评估当前生成token的好坏和对后面生成的影响。

强化学习可以很好的解决上述的两点。再回想一下Policy Gradient的基本思想,即通过reward作为反馈,增加得到reward大的动作出现的概率,减小reward小的动作出现的概率,如果我们有了reward,就可以进行梯度训练,更新参数。如果使用Policy Gradient的算法,当G产生一个单词时,如果我们能够得到一个反馈的Reward,就能通过这个reward来更新G的参数,而不再需要依赖于D的反向传播来更新参数,因此较好的解决了上面所说的第一个屏障。对于第二个屏障,当产生一个单词时,我们可以使用蒙塔卡罗树搜索(Alpho Go也运用了此方法)立即评估当前单词的好坏,而不需要等到整个序列结束再来评价这个单词的好坏。

因此,强化学习和对抗思想的结合,理论上可以解决非连续序列生成的问题,而SeqGAN模型,正是这两种思想碰撞而产生的可用于文本序列生成的模型。

SeqGAN模型的原文地址为:https://arxiv.org/abs/1609.05473,当然在我的github链接中已经把下载好的原文贴进去啦。

结合代码可以更好的理解模型的细节哟:https://github.com/princewen/tensorflow_practice/tree/master/seqgan

2、SeqGAN的原理

SeqGAN的全称是Sequence Generative Adversarial Nets。这里打公式太麻烦了,所以我们用word打好再粘过来,冲这波手打也要给小编一个赞呀,哈哈!

整体流程

8ac3e7e7deed3e4596b9966e039f618fe83b8939

模型的示意图如下:

4f784d14c59745660104aa41322b08c2189cb377

Generator模型和训练

接下来,我们分别来说一下Generator模型和Discriminator模型结构。

Generator一般选择的是循环神经网络结构,RNN,LSTM或者是GRU都可以。对于输入的序列,我们首先得到序列中单词的embedding,然后输入每个cell中,并结合一层全链接隐藏层得到输出每个单词的概率,即:

58fe04e1e6c0ebbb3362d480eb0d0373ea03c59e

有了这个概率,Generator可以根据它采样一批产生的序列,比如我们生成一个只有,两个单词的序列,总共的单词序列有3个,第一个cell的输出为(0.5,0.5,0.0),第二个cell的输出为(0.1,0.8,0.1),那么Generator产生的序列以0.4的概率是1->2,以0.05的概率是1->1。注意这里Generator产生的序列是概率采样得到的,而不是对每个输出进行argmax得到的固定的值。这和policy gradient的思想是一致的。

在每一个cell我们都能得到一个概率分布,我们基于它选择了一个动作或者说一个单词,如何判定基于这个概率分布得到的单词的还是坏的呢?即我们需要一个reward来左右这个单词被选择的概率。这个reward怎么得到呢,就需要我们的Discriminator以及蒙塔卡罗树搜索方法了。前面提到过Reward的计算依据是最大可能的Discriminator,即尽可能的让Discriminator认为Generator产生的数据为real-world的数据。这里我们设定real-world的数据的label为1,而Generator产生的数据label为0.

如果当前的cell是最后的一个cell,即我们已经得到了一个完整的序列,那么此时很好办,直接把这个序列扔给Discriminator,得到输出为1的概率就可以得到reward值。如果当前的cell不是最后一个cell,即当前的单词不是最后的单词,我们还没有得到一个完整的序列,如何估计当前这个单词的reward呢?我们用到了蒙特卡罗树搜索的方法。即使用前面已经产生的序列,从当前位置的下一个位置开始采样,得到一堆完整的序列。在原文中,采样策略被称为roll-out policy,这个策略也是通过一个神经网络实现,这个神经网络我们可以认为就是我们的Generator。得到采样的序列后,我们把这一堆序列扔给Discriminator,得到一批输出为1的概率,这堆概率的平均值即我们的reward。这部分正如过程示意图中的下面一部分:

64908b28e60594781fa75c6f7c5ff6c540a7dd56

用原文中的公式表示如下:

73afd371cb8d86743a916e946ad7c265e0c91c9c

得到了reward,我们训练Generator的方式就很简单了,即通过Policy Gradient的方式进行训练。最简单的思想就是增加reward大的动作的选择概率,减小reward小的动作的选择概率。

Discriminator模型和训练

Discriminator模型即一个分类器,对文本分类的分类器很多,原文采用的是卷积神经网络。同时为了使模型的分类效果更好,在CNN的基础上增加了一个highway network。有关highway network的介绍参考博客:https://blog.csdn.net/l494926429/article/details/51737883,这里就不再细讲啦。

对于Discriminator来说,既然是一个分类器,输出的又是两个类别的概率值,我们很自然的想到使用类似逻辑回归的对数损失函数,没错,论文中也是使用对数损失来训练Discriminator的。

04206012c26f811d2937b83989586a26ea3018ea

结合oracle模型


可以说,模型我们已经介绍完了,但是在实验部分,论文中引入了一个新的模型中,被称为oracle model。这里的oracle如何翻译,我还真的是不知道,总不能翻译为甲骨文吧。这个oracle model被用来生成真实的序列,可以认为这个model就是一个被训练完美的lstm模型,输出的序列都是real-world数据。论文中使用这个模型的原因有两点:首先是可以用来产生训练数据,另一点是可以用来评价我们Generator的真实表现。原文如下:

db52874ac408b5f6146a0eabd23b3e3c773bf3f9

我们会在训练过程中不断通过上面的式子来评估我们的Generator与oracle model的相似性。

预训练过程

上面我们讲的其实是在对抗过程中Generator和Discriminator的训练过程,其实在进行对抗之前,我们的Generator和Discriminator都有一个预训练的过程,这能使我们的模型更快的收敛。

对于Generator来说,预训练和对抗过程中使用的损失函数是不一样的,在预训练过程中,Generator使用的是交叉熵损失函数,而在对抗过程中,我们使用的则是Policy Gradient中的损失函数,即对数损失*奖励值。

而对Discriminator来说,两个过程中的损失函数都是一样的,即我们前面介绍的对数损失函数。

SeqGAN模型流程

介绍了这么多,我们再来看一看SeqGAN的流程:

75809e2e57d24cb8cedb49cfb522f86b32392381

3、SeqGAN代码解析

这里我们用到的代码高度还原了原文中的实验过程,本文参考的github代码地址为:https://github.com/ChenChengKuan/SeqGAN_tensorflow

参考的代码为python2版本的,本文将其稍作修改,改成了python3版本的。其实主要就是print和pickle两个地方。本文代码的github地址为:https://github.com/princewen/tensorflow_practice/tree/master/seqgan

代码实在是太多了,我们这里只介绍一下代码结构,具体的代码细节大家可以参考github进行学习。

3.1 代码结构

本文的代码结构如下:

6f82f7d8e1694271bbeaff9f599814d976182662

save:save文件夹下保存了我们的实验日志,eval_file是由Generator产生,用来评价Generator和oracle model相似性所产生的数据。real_data是由oracle model产生的real-world数据,generator_sample是由Generator产生的数据,target_params是oracle model的参数,我们直接用里面的参数还原oracle model。

configuration : 一些配置参数

dataloader.py: 产生训练数据,对于Generator来说,我们只在预训练中使用dataloader来得到训练数据,对Discriminator来说,在预训练和对抗过程中都要使用dataloader来得到训练数据。而在eval过程即进行Generator和oracle model相似性判定时,会用刀dataloader来产生数据。

discriminator.py:定义了我们的discriminator

generator.py :定义了我们的generator

rollout.py:计算reward时的采样过程

target_lstm.py:定义了我们的oracle model,这个文件不用管,复制过去就好,哈哈。

train.py : 定义了我们的训练过程,这是我们一会重点讲解的文件

utils.py : 定义了一些在训练过程中的通用过程。

下面,我们就来介绍一下每个文件。

3.2 dataloader

dataloader是我们的数据生成器。

e3c34f12f03ee2470edcbd68a42411ba7c72bb75

它定义了两个类,一个时Generator的数据生成器,主要用于Generator的预训练以及计算Generator和Oracle model的相似性。另一个时Discriminator的数据生成器,主要用于Discriminator的训练。

3.3 generator

generator中定义了我们的Generator,代码结构如下:

81f7636e842f274673e3a0f616152d82163fddee

build_input:定义了我们的预训练模型和对抗过程中需要输入的数据

build_pretrain_network : 定义了Generator的预训练过程中的网络结构,其实这个网络结构在预训练,对抗和采样的过程中是一样的,参数共享。预训练过程中定义的损失是交叉熵损失。

build_adversarial_network: 定义了Generator的对抗过程的网络结构,和预训练过程共享参数,因此你可以发现代码基本上是一样的,只不过在对抗过程中的损失函数是policy gradient的损失函数,即 -log(p(xi) * v(xi):

 
 


self .pgen_loss_adv = - tf.reduce_sum(

tf.reduce_sum(

tf.one_hot(tf.to_int32(tf.reshape( self .input_seqs_adv,[ -1 ])), self .num_emb,on_value= 1.0 ,off_value= 0.0 )

* tf.log(tf.clip_by_value(tf.reshape( self .softmax_list_reshape,[ -1 , self .num_emb]), 1e-20 , 1.0 )), 1

) * tf.reshape( self .rewards,[ -1 ]))


build_sample_network:定义了我们Generator采样得到生成序列过程的网络结构,与前两个网络参数是共享的。

那么这三个网络是如何使用的呢?pretrain_network就是用来预训练我们的Generator的,这个没有异议。然后在对抗时的每一个epoch,首先用sample_network得到一堆采样的序列samples,然后对采样序列的对每一个时点,使用roll-out-policy结合Discriminator得到reward值。最后,把这些samples和reward值喂给adversarial_network进行参数更新。

3.4 discriminator

discriminator的文件结构如下:

e82b20577a40b5812b9322c05eb1173322b96f26

前面的linear和highway函数实现了highway network。

在Discriminator类中,我们采用CNN建立了Discriminator的网络结构,值得注意的是,我们这里采用的损失函数加入了正则项:

 
 

with tf.name_scope( "output" ):
W = tf.Variable(tf.truncated_normal([num_filters_total, self .num_classes],stddev = 0.1 ),name= "W" )
  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值