Unlimiformer:一个Transformers输入无限长文本思路和中文长文本摘要上的性能实验

Unlimiformer:一个Transformers输入无限长文本思路和中文长文本摘要上的性能实验

1、前言

在处理长文本输入时,以往方法常采用截断(如:max_len=512)、分块(将输入分成多个块,并进行逐块处理)、长文本输入的模型(如:Longformer、BigBird和Reformer等)。由于编码器上下文窗口的固定大小,Transformer 在其最大输入长度上受到限制。本文将介绍一种能输入无限长文本的思路。名为Unlimiformer,可以扩展预训练的编码器-解码器Transformer模型的输入长度,使其能够处理无限长度的输入。传统的Transformer模型因为需要对输入中的每个标记进行注意力计算,因此输入长度通常会被限制在一定的范围内。Unlimiformer通过将注意力计算分散到一个k最近邻索引中,可以处理极长的输入序列。该方法可以应用于各种长文档和多文档摘要任务中,并且可以通过注入Unlimiformer来提高已有的预训练模型的性能,而不需要额外的训练。

文章来源:https://arxiv.org/abs/2305.01625

代码链接:https://github.com/abertsch72/unlimiformer

2、Unlimiformer

2.1、Encoding

为了对超长输入序列进行编码,Unlimiformer采用了重叠块编码的方法,并使用类似Faiss的库将编码后的输入存储在数据存储器中。

2.2、Retrieval-augmented cross-attention

Retrieval-augmented cross-attention它在标准的cross-attention上进行了改进,使得decoder不仅仅只关注encoder输入序列的前k个token,而是检索了整个输入序列中的top-k个hidden states,然后针对这些top-k的hidden states进行attention计算。这种方法不仅可以检索整个输入序列,而且计算量和GPU内存的使用也比全局attention更加高效,同时保留了99%以上的attention质量。Retrieval-augmented cross-attention的具体实现过程中,引入了一个datastore来存储编码后的输入序列,使用kNN搜索来检索hidden states,同时通过Attention reformulation的方法来优化注意力计算过程,使得可以使用单个datastore来支持所有attention heads和decoder layers的检索,从而大大降低了时间和空间复杂度。图 2 显示了对任何序列到序列转换器架构的通用更改。完整的输入使用块中的编码器进行编码并存储在数据存储中;然后,在每个解码步骤中查询编码隐藏状态的数据存储。kNN 搜索步骤是非参数的,可以注入任何预训练的 seq2seq 转换器。搜索步骤将注意力重新制定为空间效率。在下面例子中,编码器的最大输入长度为 2 个标记。6 令牌输入以块编码并存储在数据存储中。在交叉注意之前,将 Unlimiformer 注入每个解码器层。在 Unlimiformer 中,执行 kNN 搜索以从数据存储为每个注意力头选择 2 个标记上下文;然后,使用整个输入序列的键和值计算交叉注意。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-nxejMTQ1-1684218645822)(F:\weixin\imgs\image-20230516133046069.png)]

2.3、Attention reformulation

简单的说,Attention reformulation是一种针对transformer模型encoder-decoder结构中的attention机制进行改进的方法。具体而言,传统的transformer模型中,encoder和decoder各自有一个固定的context window,但是在不同的解码阶段,不同的信息可能是相关的,不同的attention头也可能关注不同类型的信息。因此,一个固定的context window可能会浪费精力在某些attention头并没有强烈关注的token上。Attention reformulation允许每个attention头在每个解码步骤中从完整的输入序列中选择一个独立的context window。这通过在decoder之前注入一个Unlimiformer查找来实现:在交叉注意力之前,模型在外部数据存储中执行一个k最近邻搜索,以选择每个解码器层每个attention头要关注的一组token。

3、局限性

  1. 需要一个外部的数据存储器来存储输入序列的编码表示,这会增加存储和计算成本。
  2. 需要进行KNN搜索来选择每个注意力头的上下文窗口,这也会带来一定的计算复杂度。
  3. 在处理非常长的输入序列时效果很好,但在处理较短的输入序列时可能会带来一些额外的计算开销。

4、生成式长文本摘要的插拔实践

from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    EarlyStoppingCallback,
    set_seed, WEIGHTS_NAME,
)

...
# 常规定义模型
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_args.model_name_or_path,
    from_tf=bool(".ckpt" in model_args.model_name_or_path),
    config=config,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=training_args.use_auth_token,
)


# 转换成Unlimiformer以兼容无限长度文本输入
from unlimiformer import Unlimiformer
from random_training_unlimiformer import RandomTrainingUnlimiformer
...
model = RandomTrainingUnlimiformer.convert_model(model, **unlimiformer_kwargs)

5、中文生成式长文本摘要上的实践表现

原文仅在英文的摘要数据集上进行实验,本文在NLPCC中文长文本摘要数据集上进行了实验小试对比:

模型性能(ROUGE-L)
BART49.074
UnlimiformerBart52.45
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值