ERICA 代码解读

目录

前言

pretrain

数据预处理

模型训练

MLM/RD loss

MLM/ED loss

小结

finetune

总结


前言

论文链接:https://arxiv.org/abs/2012.15022

开源链接:https://github.com/thunlp/ERICA

这是一篇预训练模型,主要创新点就是提出了两个辅助性预训练任务来帮助PLM更好地理解实体和实体间关系:

(1) 实体区分任务,给定头实体和关系,推断出文本中正确的尾实体。

(2) 关系判别任务,区分两个关系在语义上是否接近,这在长文本情景下涉及复杂的关系推理。
 

为了避免灾难性遗忘,作者同时还加了masked language modeling (MLM)这一传统任务,所以总loss就是:

ED就是实体区分任务、RD就是关系判别任务、MLM就是传统屏蔽任务

更多详细解读可以看如下,笔者不再累述,本篇主要目的是解读代码。

ERICA: 提升预训练语言模型实体与关系理解的统一框架

pretrain

数据预处理

该部分代码逻辑在./pretrain/prepare_pretrain_data

get_distant.py:数据清洗,实体抽取和关系抽取

remove_test_set.py:区分训练集和测试集

sample_data.py:tokenized化,通过这样预处理。

这里没什么要说的,笔者比较感兴趣的是实体关系抽取是怎么做的。其实很简单,这里没有什么模型啥的,最主要的就是靠下面几个文件:

all_triple.txt:定义了实体关系

all_name_to_Q.json:实体名到类型的一个json

all_Q.json:所以实体类型id的。

关于实体抽取就是匹配,依靠上述文件,只要匹配到就得到实体。关于关系抽取更简单了,只要实体类型定了那么依靠all_triple.txt就确定了关系。

其中./pretrain/data/DOC/sampled_data/下就是官方给出的一个预处理完的数据结果,可以看看

模型训练

主要逻辑是在./pretrain/code/pretrain下,主入口就是main.py,主要就是:

根据论文我们知道模型主要涉及到三部分loss【ED/RD/MLM】

红色框的doc_loss就是【mask loss + 关系判别即 MLM + RD】,绿色框的wiki_loss就是【mask loss + 实体区分即MLM + ED】

我们来一部分一部分看,主要是在model.py中

可以看到主要就是对应两个函数236行和239行即get_doc_loss和get_wiki_loss函数,需要注意的是两个函数的输入是不一样的,即batch[0]和batch[1],关于输入数据的格式可以看dataset.py:

主要就是730行,其实就是get_doc_batch和get_wiki_batch两个函数。好了,大概代码逻辑框架知道了,下面分开看:

MLM/RD loss

数据输入就是:get_doc_batch

模型就是:get_doc_loss

如下是get_doc_loss

可以看到,

----------------------------------------------------小分界线 MLM LOSS-------------------------------------------------

MLM loss就是用的huggingface框架的BertForMaskedLM或RobertaForMaskedLM【作者即想实验bert又想实验roberta】,但是由于huggingface框架的MaskedLM返回只有loss,代码中后续还想使用sequence output ,所以作者将transformers源码下载了下来,修改了一下,其实就修改了源码中pretrain\code\pretrain\transformers\modeling_bert.py中的一行即:

增加了一行1014代码,与不加之前的区别在于多返回了个变量即sequence_output,这个后续再算其他两种loss【RD/ED】要用。

同时通过看MLM loss我们可以发现mask_tokens这个函数,这个函数主要就是写了传统的mask策略如下

从这里还可以看到,有个not_mask_pos参数,就是传统的mask都是随机mask,这个字段代表一些位置不能被mask,即使随机到了也不行,比如对于本项目中我们在关系判别任务中,头尾实体是不能被mask的,因为我们需要其预测关系,所以not_mask_pos就会记录一些实体位置。

总之我们从该项目的MLM loss这里,我们至少可以学到三点:

(1) 以后我们有什么自己的mask策略想法,想落地实现,其实就是仿效改这个函数。

(2) 只需要修改完上述函数,直接传到对应的huggingface框架下的ForMaskedLM【比如BertForMaskedLM】,就可以直接到返回的Loss,进而进行MLM语言屏蔽模型训练

(3) 遇到一些特殊需求需要改huggingface框架也不是不可以,直接下载transformers代码进行需求修改即可

---------------------------------------------------小分界线 MLM LOSS-------------------------------------------------

以上的MLM loss就是传统的预训练模型,不是本文的创新点,下面我们来看看论文的创新点RD loss也即关系区分任务【接着看上图的get_doc_loss函数,为了方便,这里再放一次】

作者这里用了对比学习:正样本即具有相同远程监督标签的关系表示,负样本与此相反,关于关系的表征,就是其对应的两个实体的简单拼接,即上述代码的173行得到的hidden。

start_re_output和end_re_output可以看做是头实体和尾实体表征。

context_output就是我们上述修改transformers源码返回的sequence ouput

h_mapping和t_mapping是batch传进来的,可以通过get_doc_batch看到就是代表的实体位置,然后通过和context_output相乘就可以滤除全部头实体和尾实体的编码表征

至此用pair_hidden【hidden】和relation_label通过对比学习计算loss【NTXentLoss_doc函数】

对比学习原理这里不在累述,感兴趣的可以看笔者另外一篇博客:

https://blog.csdn.net/weixin_42001089/article/details/117930433

这里对应的公式就是:

该小节的函数get_doc_loss最后返回就是m_loss和r_loss即MLM loss和ED loss也即屏蔽语言模型loss和关系区分loss

MLM/ED loss

数据输入就是:get_wiki_batch

模型就是:get_wiki_loss

首先206行返回的就是mlm loss,前面已经讲过,这里不在累述,一模一样,重点看看ED loss

他的原理是根据头实体和关系预测尾实体

start_re_output可以看出是头实体,而query_re_output可以看做是关系,我们知道paper的关系表征是头尾实体的简单拼接,所以query_re_output是通过query_mapping得到的,可以理解为query_mapping是当前关系对应头尾实体位置,通过和context_output相乘就过滤出对应头尾实体,进而进行拼接得到关系表征,关于query_mapping是batch得到的,可以看get_wiki_batch

该小节的函数get_wiki_loss最后返回就是m_loss和r_loss即MLM loss和RD loss也即屏蔽语言模型loss和实体区分loss

小结

(1) 新加的两个辅助任务是分开进行的【过两次模型】,但二者每次都顺便带了mlm loss

(2) mlm 部分给了很多落地启发,即自己有了个什么想法,能快速使用transfomers实现,尤其二次预训练,甚至我们可以改源码。

finetune

代码在finetune,这里面每一个文件夹代表一个下游任务,没什么可讲的,主要就是用上述pretrain得到的模型去热启就行了。

总结

(1) 以后我们有什么自己的mask策略想法,想落地实现,其实就是仿效改这个函数。

(2) 只需要修改完上述函数,直接传到对应的huggingface框架下的ForMaskedLM【比如BertForMaskedLM】,就可以直接到返回的Loss,进而进行MLM语言屏蔽模型训练

(3) 遇到一些特殊需求需要改huggingface框架也不是不可以,直接下载transformers代码进行需求修改即可

欢迎关注笔者的微信公众号,更多好文章:


​​​​​​​

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值