A Diversity-Promoting Objective Function for Neural Conversation Models 论文阅读零散笔记

A Diversity-Promoting Objective Function for Neural Conversation Models

SEQ2SEQ模型用于conversational responses倾向于产生safe, commonplace的response,比如(“I don’t know”)。这篇论文中作者提出了MMI(Maximum Mutual Information)作为object function 而不是原来的MLE。

MLE求得给定input message sequence S S 下,取得target sequence T的公式为:
Tˆ=argmaxT{logp(T|S)} T ^ = a r g m a x T { l o g p ( T | S ) }
上述公式对于高频的generic response敏感,而类似“I don’t know”和“i don’t know what you are talking about”这种回答是高频的,而且是dull的,这种高频可能使回答并不能很好的契合 S S
作者提到了互信息的公式:
logp(S,T)p(S)p(T)
确保S与T相关
Tˆ=argmaxT{logp(T|S)logp(T)} T ^ = a r g m a x T { l o g p ( T | S ) − l o g p ( T ) }
当然实际上采用的是
Tˆ=argmaxT{logp(T|S)λlogp(T)} T ^ = a r g m a x T { l o g p ( T | S ) − λ l o g p ( T ) }
可以理解成是在 Tˆ=argmaxT{logp(T|S)} T ^ = a r g m a x T { l o g p ( T | S ) } 上增加了一个惩罚项 logp(T) l o g p ( T ) ,对于高频的T惩罚力度大,低频的惩罚力度小。

经过贝叶斯公式推导可写成
Tˆ=argmaxT{(1λ)logp(T|S)+λlogp(S|T)} T ^ = a r g m a x T { ( 1 − λ ) l o g p ( T | S ) + λ l o g p ( S | T ) }

因采用的公式的不同,作者提到了两种方法MMI-antiLM和MMI-bdi。当然作者提到它们都可能产生ungrammatical output,因而两种方法在原公式的基础上作出了一定的修改。

MMI-antiLM

p(T)=Ntk=1p(tk|t1,t2,...,tk1) p ( T ) = ∏ k = 1 N t p ( t k | t 1 , t 2 , . . . , t k − 1 )
序列T中每一个token ti t i 出现都是与前面的i-1个token相关的(考虑SEQ2SEQ中attention机制的存在),因而出现的概率是组成它的各个token的联乘形式。
其实由于句子长度各不完全一致,T的长度不是一个定值,导致对于不同长度的T, p(T) p ( T ) 的数量级差别很大。因此我个人认为需要考虑T的长度因素。事实上,作者是这么考虑的,只不过是放到了后面再说。
被改写为
U(T)=Ntk=1p(tk|t1,t2,...,tk1)g(k) U ( T ) = ∏ k = 1 N t p ( t k | t 1 , t 2 , . . . , t k − 1 ) ⋅ g ( k )
其中

g(k)={10ififkγk>γ g ( k ) = { 1 i f k ≤ γ 0 i f k > γ

γ γ 是选定的threshold.
这样公式就变成了:
Tˆ=argmaxT{logp(T|S)λlogU(T)} T ^ = a r g m a x T { l o g p ( T | S ) − λ l o g U ( T ) }
作者的意图有二:
其一,在SEQ2SEQ模型中,上一个输出的单词在很大程度下决定着下一个单词,因而T序列中靠前的单词对于整个序列的影响更大,penalize前面的单词相比penalize后面的单词更能确保diversity
其二,ungrammatical segments 更可能出现在句子的后半部分(特别是长句子)。

上式中公式中 g(k) g ( k ) 的取值决定了,只是对长度超过阈值 γ γ 的序列才施加 U(T) U ( T ) 的惩罚。

MMI-bidi

Tˆ=argmaxT{(1λ)logp(T|S)+λlogp(S|T)} T ^ = a r g m a x T { ( 1 − λ ) l o g p ( T | S ) + λ l o g p ( S | T ) }
由于SEQ2SEQ每一步产生的是每个单词作为输出可能的概率,其每一步的输出都会使得输出中的待选择的序列成指数倍增长,待选序列T太多了,对于每一个T计算 logp(S|T) l o g p ( S | T ) 并不现实,所以实际操作中,先按照
Tˆ=argmaxT{logp(T|S)} T ^ = a r g m a x T { l o g p ( T | S ) } 选择出N-best list, 这N个应该是generally grammatical的, 再用上面的公式进行计算来对这N个序列来rerank。

实际中,作者注意到了序列的长度 Nt N t 在训练中是不可忽视的,因此对于上面公式都加上了 γNt γ N t
即对于MMI-antiLM
Score(T)=p(T|S)λU(T)+γNt S c o r e ( T ) = p ( T | S ) − λ U ( T ) + γ N t
MMI-bidi同

实验部分用来俩数据集 Twitter Conversation Triple Dataset 以及 OpenSubtitles数据集。
具体就不说了。
本文完

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值