Alleviating Exposure Bias via Contrastive Learning for Abstractive TextSummarization

文章提出了一个方法来解决文本摘要模型在训练和推理阶段的差异,即偏差暴露问题。通过对比学习,模型在训练中不仅能增加黄金摘要的可能性,也能降低低质量(白银摘要)摘要的出现,从而缓解暴露偏差。实验显示,这种方法能有效提升模型在不同数据集上的性能。
摘要由CSDN通过智能技术生成

论文地址

摘要

文本摘要模型在训练和推理阶段存在着差异,即偏差暴露问题。在训练阶段,在teacher-force下,这些模型被优化,以最大限度地提高黄金摘要的可能性,给定黄金摘要令牌作为解码器的输入,而在推理时,给定的令牌被生成的令牌替换。因此,很可能产生低质量的摘要。.为了解决这个问题,我们建议利用对比学习来降低这些低质量摘要的可能性,同时增加黄金摘要的可能性。由于我们的解决方案扩展了模型在训练期间感知到的状态,因此我们希望可以缓解暴露偏差问题。我们通过实验证明,我们的方法有效地证明了不同数据集上最先进模型的性能。

1、介绍

大多模型在经过预训练之后,通过teacher-forcing的方式进行下游任务微调。然而teacher-forcing会导致训练阶段和推理阶段的差异。他们的训练对象是最大化黄金摘要中每个token的可能性,给定其以前的token。然而,在推断时,黄金摘要token不可用,它们必须被生成的token替换。这意味着模型在训练时经常在状态空间的有限部分进行优化,因此其性能极有可能下降。这个问题就是众所周知的偏差曝光问题。它可能会导致沿着生成的token快速累积的严重错误。因此,这些模型特别容易产生意想不到的摘要,与黄金摘要相比,我们称之为“白银摘要”。更糟糕的是,由于这个问题,愚蠢的摘要往往包含虚假事实,其表面形式可能与文本相似,但实际上与其原始含义相反。

为了减轻暴露偏差,我们建议利用对比学习方法来扩展模型在训练过程中感知到的状态。我们不仅希望通过最大似然估计(MLE)增加黄金汇总的可能性,而且还希望通过对比学习(CL)在训练过程中降低白银汇总的可能性。在一定程度上,它有助于防止模型生成白银摘要。这种方法也可以被视为一种特殊的数据增强策略,它使模型能够从正样本(黄金摘要)和负样本(白银摘要)中学习。当推理时生成的银色摘要明确参与另一轮训练时,训练和推理之间的差异可以减少,从而可以减轻暴露偏差。

我们在三个基准数据集上进行实验以验证我们的方法,包括Xsum,CNNDM和Multi-News。实验结果表明,我们的方法可以有效提高最近发布的PEGASUS的SOTA模型的性能。

2、方法

2.1、问题定义

在teacher forcing下,学习目标是在输入文本 X 和黄金摘要中先前的Tokens  y<i的条件下,最大化黄金摘要 Y 中每个Token y_{i} 的可能性。损失函数定义为负对数似然 (NLL),如下所示:

 其中f(y_{i}|X,y<i) 是黄金摘要 Y 的第 i 个Token的对数似然。

在推理时,模型必须使用生成的token  来预测token \hat{y_{i}} 。通常,基于beam search 分数S,beam search算法被用于在每个时间步长获取多个备选方案。然后,模型通过波束搜索一个token一个token地生成候选摘要,并选择波束搜索得分最高的一个作为输出摘要。具有与输入文本X相关联的m个Tokens的一个备选序列 \hat{Y} 的波束搜索得分如下计算:

 其中 是生成序列 \hat{Y} 的第 i 个Token的预测对数似然,表示早于 token\hat{y_{i}} 的token,\beta是与序列长度相关的附加指数惩罚。对于文本摘要任务,\beta小于1.0,以避免生成冗余信息。

当通过NLL损失L_{nll}对数据集进行训练时,黄金摘要的分数S预计会上升,因此更有可能将黄金摘要作为候选摘要之一,从而选择黄金摘要作为生成的候选摘要中的最终输出。

但是,也可能存在着更高的得分S,然而质量却低的候选摘要。当它作为输出摘要时,被称为“白银摘要”。白银摘要的出现可归因于差异问题,因为 seq2seq 模型只能在训练时观测黄金摘要,而模型需要在推理时评估大量看不见的替代方案。这个问题就是众所周知的曝光偏差。

2.2、对比学习

为了缓解上述问题,我们建议在训练期间通过对比学习明确降低银色摘要的分数S,这是受到提取性摘要中使用的类似学习方法的成功启发。

具体来说,在我们的方法中,seq2seq模型进行了优化,以确保“pos分数”高于“neg分数”。对于相同的文本X,pos分数S(Y|X) 通过公式(2)计算黄金摘要,而neg分数 S(\hat{Y}|X) 的计算方式相同,但使用白银摘要。margin ranking loss损失被定义为增加pos分数,同时降低neg分数,如下所示:其中\gamma是margin value。

 注意,如果pos分数高于超过边缘值\gamma的neg分数,则无法优化模型,因为当L_{con}值为零时,梯度也为零。为了最有效地利用训练数据并防止模型欠拟合,我们还将NLL损失包含在整体损失函数中,即:

 2.3、模型训练

通过对比学习训练 seq2seq 模型的工作流程如图 1 所示。pos分数和neg分数都是基于相同的有编码器和解码器结构的seq2seq模型计算的。我们的方法不局限于原版 seq2seq 模型,因此 seq2seq 模型还可以包含复制机制和覆盖机制。虽然它们共享输入文档的编码器输出,但这两个分数分别使用黄金摘要和银色摘要作为解码器输入来计算。这意味着我们的方法主要在解码器端施加影响

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值