CogLTX : bert处理长文本代码解析

前言

bert处理长文本,这里不是改transformer这条路,关于目前transformer变种感兴趣的可以看下:

https://blog.csdn.net/weixin_42001089/article/details/114452385

本文的方法是使用了一个评分模型
核心思想就是:
将长文本分割成一个个块,设计一个judge模型,来负责给各个块打分,然后下游使用得分高的块来组合在一起进行训练!
关于分数怎么定义,对于qa这种,可以用query得到一些天然的关键句,依次监督judge训练,但是对于普通的分类任务这种,没有这种天然监督,那是通过Loss来动态更新的,即如果drop一个块后,loss徒增,那么其就是关键句,就可以打标为正样本。

大体上是一种时间换空间的做法。
在这里插入图片描述

更多细节直接看论文吧,本篇主要结合代码看一下其具体实现的一些细节
论文:
http://keg.cs.tsinghua.edu.cn/jietang/publications/NIPS20-Ding-et-al-CogLTX.pdf

代码:
github: https://github.com/Sleepychord/CogLTX

(1)代码大体框架,这里可能先看的话有一些不理解,这里只是梳理了一下大纲,可以结合后面细节返回来看看。

(2)这里只看了分类,qa没看,qa更简单吧,其没有loss训练更新relevance部分,因为其有天然关键句,但是这部分打标应该在数据预处理部分也会涉及到。

(3)分类任务没有天然的label,是不断通过loss动态更新的,关于其初始化,可以使用BM25等先验知识来初始化打标一下,代码中20news这个例子其实并没有用,全部初始化为0的。

代码大体框架

(1)main_loop.py
主流程

(2)buffer.py
Block类:其是封装一个块的,一些比较重要的属性有relevance,estimation,pos(主要就是在更新数据的relevance,estimation时需要全局唯一定位)

buffer.py中的Buffer类:其是封装一个样本的,即一个样本是由多个block对象组成的列表,其有比较重要的方法如export,export_relevance用来给tensor装数据的。

(3)data_helper.py
BlkPosInterface类:主要就是两个训练流程之间数据需要交互,以及我们块怎么选取一些块组合在一起等策略都是在这里定义的。比较重要的方法是build_random_buffer:judge部分的块选取组合策略、build_promising_buffer:reasoner部分的块选取组合策略、collect_estimations_from_dir:加载judge计算的estimations,更新全局数据集的estimations,reasoner的build_promising_buffer选取策略要用到、 apply_changes_from_dir就是加载reasoner 部分通过_intervention计算的结果来动态更新全局的relevance。

(4) introspector_module.py
就是judge模型,比较重要的就是
training_step:里面就是正常的训练分类模型(关键句即token级别的分类,标签是relevance),比较重要的是self._write_estimation(buf, _score_blocks(buf, torch.sigmoid(logits[i]))),就是保存好estimation分数,供reasoner使用

(5)reasoner_module.py
就是具体任务模型,比较重要的就是
training_step:里面也是正常的训练分类模型(真真下游任务分类),比较重要的是:
self._intervention:对于分类模型,用loss来计算块的relevance。
self. _write_changes:就是保存好relevance,供apply_changes_from_dir加载。

(6)models.py
就是一些基础模型

(7)memreplay.py
_score_blocks:计算块分数即estimations

数据预处理

首先是数据预处理部分,其主要是将长文本切分为块,即如下3个文件夹对应3个不同数据集的预处理脚本。下面就挑20news这个来看看吧。
在这里插入图片描述
首先就是按标点符号分隔开,如果两个逗号中间的文本过长(大于B=63),那就按B再切分,然后再合并各个块,合并的原则就是看标点符号,举个例子吧。

假设有5个块,第一个块结尾是逗号,第二个结尾是句号,第三个和第四个是由于原来该块过长被分开成两个,第五个块是以句号结尾的。每一个块有一个cost,如逗号是2,句号是1,第三个和第四个都是8。
那么我们从第五个块开始向前合并,首先看第四个块,其cost是8,记下来,第三个块也是8,我们一直向前找到最小的即第二块的1,最好就将第2,3,4,5合并成一个。其实就是说如果合并逗号,那么就将两句以逗号连接的文本分到了两个块中,代价比较大,如果是句号呢,其前后两句话被分到两个块就比较合理。来看看代价的定义吧:

end_tokens = {'\n':0, '.':1, '?':1, '!':1, ',':2}

可以看到换行符代价基本没有,句号问号叹号代价次之,逗号再次之,也就是说句号问号叹号等前后两句话被分到两个块比较合理,比较极端的如对于换行符就没有代价,还有第三块和第四块本是一句连续的文本,所以其代价cost应该更大,代码中定义的是8 。

大体原理就是如上,下面结合代码看一下 process_20news.py

主要就是buffer.py下的Buffer类的split_document_into_blocks方法

假设有一段文本,经过robert 的 tokenizer后
在这里插入图片描述
然后下面的代码就是最关键的
在这里插入图片描述
64-71行就是按64行定义的标点符号分割成一个一个块,71行的poses结果如下:
[(-1, 0), (8, 2), (16, 2), (36, 1), (39, 2), (51, 2), (61, 2), (73, 2), (81, 2), (84, 1), (89, 2), (130, 2), (139, 2), (145, 2), (161, 1), (171, 1), (180, 1), (188, 2), (204, 1), (221, 2), (227, 1), (261, 2), (278, 1), (285, 1), (422, 0)]

第一个元素是代表块的(标点)起始位置,第二个元素就是原理部分所说的cost代价。
73行到76就是检查一下有的块是不是长度大于论文中所说的B即这里的BLOCK_SIZE=63,有的话就再切开,运行到76行时,此时的poses就是:
[(-1, 0), (8, 2), (16, 2), (36, 1), (39, 2), (51, 2), (61, 2), (73, 2), (81, 2), (84, 1), (89, 2), (130, 2), (139, 2), (145, 2), (161, 1), (171, 1), (180, 1), (188, 2), (204, 1), (221, 2), (227, 1), (261, 2), (278, 1), (285, 1), (348, 8), (411, 8), (422, 0)]
由于422到285长度大于63,所以又分出加粗的两个块,且其代价更大为原理部分所说的8。
78行到89行最关键的,也就是原理部分所说的,86行cost和sen_cost对于当前合并其实是定值,所谓的动态是best[j][1],它就是不断的向前找代价最小的,保存到当前的best[i],当前不能无限制合并,确保合并的总长度要小于BLOCK_SIZE,即84行。
89行的best是:
[(0, 0), (0, 6), (0, 6), (0, 5), (0, 6), (0, 6), (0, 6), (3, 11), (3, 11), (3, 10), (3, 11), (9, 16), (9, 16), (9, 16), (13, 21), (13, 21), (13, 21), (13, 22), (13, 21), (18, 27), (18, 26), (18, 27), (20, 31), (20, 31), (23, 43), (24, 55), (25, 59)]
这里大体意思就是(0, 0), (0, 6), (0, 6), (0, 5), (0, 6), (0, 6), (0, 6)是合并到一个块的
(3, 11), (3, 11), (3, 10)是一个块的
(9, 16), (9, 16), (9, 16)是一个块的等以此类推。
90到94行就是根据best记录的合并信息和poses记录的各个块的起始位置得到最终的各个块起始和最终位置,即intervals是:
[(412, 423), (349, 412), (286, 349), (228, 286), (205, 228), (146, 205), (85, 146), (37, 85), (0, 37)]
一共将长文本分为了9个块。
最后就是将得到的块封装一下即封装成Block类,然后插入到ret,ret是buffer类
在这里插入图片描述
其中封装成的Block类定义就是
在这里插入图片描述
总结一下:
预处理的结果,返回两个东西,一个是ret,其是buffer类,它的属性blocks是一个列表,其中一个个元素就是封装好的Block块。一个是cnt就是记录的块的总数即这里就是9

主训练

主要分为两部分即introspector和reasoner。其中interface是数据接口用来筛选和组合最新的块,来更新

在这里插入图片描述
其中54到57行就是introspector训练,对应的就是论文中的judge
59到63行就是reasoner就是真真的下游任务
64到63行就是在分类任务中,去掉某一个句子看loss是否变大,来更新句子的label(关键句)
需要明确block分装定义的几个变量

在这里插入图片描述ids是当前块文本经过tokenizer编码后的id,pos是在整个数据集上对各个块一个位置,方便卫衣定位,blk_type是区分sentence_a和sentence_b的,对于分类来说就是将选中的块拼接成一句话,没有所谓的sentence_a和sentence_b,这里都是1,对于问答,query就是0,后面文本的是1,relevance就是关键句标识(>=1就是相关的,说白了就是训练judge的标签,对于qa问题,可以得到天然的关键句标识,因为我们知道答案在哪个块中,但是对于分类我们不知道,所以是通过drop一句话来看loss变化,如果变化太大那就是关键句,通过动态给打标,然后judge就是依据这个打标的label来更新自己的网络,注意这里的关键句label是非0即1,即如果是关键句,该块所有token的label都是1),estimation就是judge预测的当前块是不是关键句,也就是论文中所说的最关键的score,它是一个平均值,假设当前块是“我爱中国”,introspector(judge)预测是[0.1,0.8,0.7,0.6],那么综合所有token取mean作为该块的最终分数即(0.1+0.8+0.7+0.6)/4。

还有一个比较关键的就是关于estimation动态更新relevance,也就是动态打标这个怎么打呢?
其实这里设置了两个门限(上门限:0.2,下门限:-0.05)

假设没有drop前的loss是a,drop某个块后loss是b
(1)那么b-a如果大于上门限,那么就是关键句,那么更新其标签即relevance+1
(2)那么b-a如果小于下门限,那么就是说drop后loss还降了,那么就是说其更不重要了,更新其标签即relevance-1。
(3)其他情况就是正常,不更新了,保持原标签。

============================================================
下面我们结合代码来看看,脑海中始终记住两个重要变量即
relevance:相当于标签,qa有天然的,分类模型需要通过Loss动态更新。
estimation: introspector(judge)预测的块的标签(分数)。
main_loop.py就是主流程训练,最关键的还是上面那副图,这里再拿过来
在这里插入图片描述同时吧论文中的算法流程拿过来一起对比着看
在这里插入图片描述

两个模型(introspector,reasoner),一个数据接口interface。
训练大体上分为两部分,第一步的是introspector,就是judge训练,第二部分是reasoner就是具体的下游任务。两个训练过程全程使用interface这个数据接口,包括怎么选取块组合以及动态打标的交互等都是通过这个数据接口完成的,很关键,其其实是data_helper.py下的BlkPosInterface类。

introspector训练

代码中54行build_random_buffer函数其实是在做算法流程中的第5和第7行,即选择一部分快组合成训练数据。
在这里插入图片描述该函数的78行到84行就是随机选。对应算法图片的第5行。

###注意:qbuf和dbuf是一个样本,还记得在数据预处理部分吗?两者其实都是一个列表,列表中的元素都是一个buffer对象,qbuf列表其实就一个元素(一个块)该块对应的文本是cls,dbuf是真真的文本分割成块组成的一个块列表,那么为什么要qbuf这个东东呢?因为我们在不断动态选取块拼接形成新的模型输入时,因为用的是bert,第一个token必须是cls,所以为了方便,我们每个样本都是以qbuf开始的。

86行到92行对应算法图片的第7行。其中86行的dbuf.filtered就是根据relevance将块分为正负样本,可以看到relevance大于等于1的被视为正样本即pbuf,小于1的被视为负样本是nbuf,88行是抽取所有正样本,89行是抽取了部分负样本。91行就是将两个收集的结果作为当前样本的块组合。同时可以看到不论是83行还是91行都是以qbuf开头的,即一个样本是要以cls开头的!后面是各个块拼接组成一个完整的输入(其实各个块在数据预处理部分都在结尾加上了sep,所以总体拼接后满足bert输入形式)。
好啦,回到 introspector的训练流程中,上述得到了intro_dataset,通过

introspector.set_dataset(intro_dataset)

将其设置了introspector的数据集,这里其实也没啥就是分了一下train,eval和test数据集。
最后就是

trainer = _create_new_trainer(epoch + 1, logger_intro) 
trainer.fit(introspector)

训练introspector了。那么introspector模型是什么呢?很简单即introspector_module.py即训练流程是
在这里插入图片描述102-103行相当于用当前这一个batch给bert需要的input_ids,attention_mask,token_type_ids输入装数据
在这里插入图片描述大体上就是在238行到243行,ids就是换成bert的ids,att_masks就是1,对于分类就一个sentence。

 buf.export_relevance(device=self.device, out=inputs[3, i])

就是根据relevance来给label tensor装数据
在这里插入图片描述可以看到,当一个块的relevance>=1时label就是1,否则就是0啦。

 loss_introspector, logits = self.introspector(*inputs[:3], labels=inputs[3])

就是用上面装好的数据进行训练啦, loss_introspector就是整个token级别的loss,logits就是预测结果,注意是各个token的预测结果。
self.introspector就是一个简单的robert分类模型,这里不在累述。

 for i, buf in enumerate(bufs):
     self._write_estimation(buf, _score_blocks(buf, torch.sigmoid(logits[i])))

是将introspector预测的结果转化成分数保存起来,供后续reasoner使用,其实就是论文中judge计算的score,来看看_score_blocks,该函数在memreplay.py中

在这里插入图片描述第8行返回一个列表,记录的是各个块的长度,第11行这里的blk_type其实为了在qa场景下剔除query这个块的,该块不需要计算分数,对于当前我们分类blk_type都是1。通过12行可以看到就是该块内各个token的平均值作为该块的最终score。
最后通过_write_estimation函数将其写在了一个命名为estimations_{}.txt临时文件。

reasoner训练

在这里插入图片描述对应的就是这里的59行到63行。
59行就是加载introspector部分保存的estimations_{}.txt临时文件,更新全数据的estimation,注意这里的self.d就是一个全局的块映射,还记得之前说的pos这个吗,就是为了唯一全局定位块的,如下
在这里插入图片描述60行就是reasoner的组合块策略了,对应算法图片的话就是11到14行。
在这里插入图片描述102还是使用relevance将样本分为正样本和负样本两部分(正样本就是该块是关键的),然后106-108行就是在上述负样本中根据judge计算的estimation进行排序,取topk,这里可能有一点小小的疑问?就是根据estimation统一对所有块进行排序不就行了?为什么要先用relevance分一下呢?这就是说对于qa问题我们天然知道一些块就是关键的,我们不必通过judge来预测的estimation来选,我们就是强行要把这些保留下来的,即这里的pbuf,在nbuf这些负块中我们不知道哪些重要,才用judge的estimation来排序选择。111到114行就是相当于选取estimation高的topk,115到118就是相当于不能总是选estimation高的,因为judge也有误差,判断的也不一定标准,我们得在nbuf随机选一下,给那些低分块也有一定上位的机会!

好啦数据租好了,那么接下来就是训练了,看看reasoner部分的模型:
在这里插入图片描述类似introspector部分,这里的144行-148行就是给bert输入tensor装数据,不再累述,149行就是主模型,也很简单,也是一个robert分类模型,这里就是真真下游任务的分类训练模型。
下面开看看153行到154行的self._intervention函数,他就是用loss,来更新关键句label的,对应到算法图片就是17到21行。可以看到self.config.latent是一个控制参数,对于qa问题,有天然的query监督,就不需要这一步。
在这里插入图片描述这就是根据,loss来更新relevance,比较关键的更新策略就是135行到138行,关于解释见前面讲的原理。需要注意的是,如135行的当loss徒增且此时该块的relevence也比较大时(大于2),才更新它(重要度减1).可以看到这里通过self.write_changes将结果写到changes{}.txt一个临时文件中。
到此reasoner训练结束。回到reasoner主流程中还有最后一部分

 if config.latent and epoch > 1:
       interface.apply_changes_from_dir(config.tmp_dir)

这是更新全局的relevance的,上述我们只是将计算得到的relevance值写到一个临时文件中了,我们还未对当前数据的relevance值更新。

在这里插入图片描述apply_changes_from_dir就是加载changes_{}.txt文件,然后用apply_changes_from_file解析出来relevance,然后用set_property来更新块的relevance。
最新的:
https://mp.weixin.qq.com/s/PtPX8tfyVVa0_8IZw–jUg

看到很多小伙伴私信和关注,为了不迷路,欢迎大家关注笔者的微信公众号,会定期发一些关于NLP的干活总结和实践心得,当然别的方向也会发,一起学习:

在这里插入图片描述

  • 8
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值