SimCTG:缓解GPT2在生成任务上token的各向异性
GPT2无疑是目前在文本生成任务上使用最多的模型,但由于decoder类模型生成的token具有各向异性(即一个句子中token向量距离较近,容易发生重复生成的问题)。2022年2月,腾讯AI lab针对这一问题提出了SimCTG 一种对比损失函数,来缓解token各向异性,同时在解码端也使用了对比搜索Contrastive Searching来提高解码搜索能力。
paper地址 A Contrastive Framework for Neural Text Generation
代码地址 yxuansu/SimCTG: A Contrastive Framework for Neural Text Generation
github上介绍很齐全,包括英,中,日,韩四国语言的文本生成模型,还有生成对话模型,Open-End故事生成模型,以及除了GPT2这种decoder外,encoder-decoder模型如何BART和T5如何使用SimCTG进行摘要任务。
Playground
-
中,英,日三国语言,文本生成实例 simCTG.ipynb(github.com)
-
Open-End 故事生成 在WritingPrompts 和 ROCStories数据进行fine-tune:SimCTG/story_generation at main · yxuansu/SimCTG (github.com)
-
使用Bart和T5进行摘要任务 SimCTG/SimCTGEncDec at main · yxuansu/SimCTG (github.com)
上面是我写的一个例子,可以在上面中文对话生成代码链接中查看全部内容,感觉根据对话生成的内容还可以吧。
模型介绍
为了解决GPT2生成任务解码时,有时重复生成的问题,是由于推理出的token表示空间的各向异性。SimCTG在传统损失上添加了对比损失函数,使得同一句子不同token表示差异性尽量大,缓解各向异性的表示问题。
L
S
i
m
C
T
G
=
L
M
L
E
+
L
C
L
L_{SimCTG}=L_{MLE}+L_{CL}
LSimCTG=LMLE+LCL
最后Loss等于最大似然估计和对比loss的和
解码也使用类似loss计算的,其中s就是hx(词向量)的相似度得分
总结
本文就是为了解决GPT2等文本生成模型,由于生成token的各向异性,导致degeneration 情况发生,所以提出了让生成的token的差异性更大,在训练中体现在修改loss计算,在推理中体现在decode解码方式。总的来说确实可以改进生成文本的重复生成问题,然后探讨了在故事生成,文档生成,多种语言生成,以及在生成摘要任务(T5,Bart)上的使用,创新性不足,但实验的工作量很多,有一定借鉴加载,而且源码githud代码库写的很好,方便复现和实验。
后续工作
Magic 模型 https://github.com/yxuansu/MAGIC
添加CLIP相似度得分,来生成与图片相关的文本,我会在下面CLIP的后续工作中介绍,而有想了解CLIP的可以看我前面的文章。CLIP学习笔记