目录
前言
论文链接:https://arxiv.org/abs/2012.15022
开源链接:https://github.com/thunlp/ERICA
这是一篇预训练模型,主要创新点就是提出了两个辅助性预训练任务来帮助PLM更好地理解实体和实体间关系:
(1) 实体区分任务,给定头实体和关系,推断出文本中正确的尾实体。
(2) 关系判别任务,区分两个关系在语义上是否接近,这在长文本情景下涉及复杂的关系推理。
为了避免灾难性遗忘,作者同时还加了masked language modeling (MLM)这一传统任务,所以总loss就是:
ED就是实体区分任务、RD就是关系判别任务、MLM就是传统屏蔽任务
更多详细解读可以看如下,笔者不再累述,本篇主要目的是解读代码。
pretrain
数据预处理
该部分代码逻辑在./pretrain/prepare_pretrain_data
get_distant.py:数据清洗,实体抽取和关系抽取
remove_test_set.py:区分训练集和测试集
sample_data.py:tokenized化,通过这样预处理。
这里没什么要说的,笔者比较感兴趣的是实体关系抽取是怎么做的。其实很简单,这里没有什么模型啥的,最主要的就是靠下面几个文件:
all_triple.txt:定义了实体关系
all_name_to_Q.json:实体名到类型的一个json
all_Q.json:所以实体类型id的。
关于实体抽取就是匹配,依靠上述文件,只要匹配到就得到实体。关于关系抽取更简单了,只要实体类型定了那么依靠all_triple.txt就确定了关系。
其中./pretrain/data/DOC/sampled_data/下就是官方给出的一个预处理完的数据结果,可以看看
模型训练
主要逻辑是在./pretrain/code/pretrain下,主入口就是main.py,主要就是:
根据论文我们知道模型主要涉及到三部分loss【ED/RD/MLM】
红色框的doc_loss就是【mask loss + 关系判别即 MLM + RD】,绿色框的wiki_loss就是【mask loss + 实体区分即MLM + ED】
我们来一部分一部分看,主要是在model.py中
可以看到主要就是对应两个函数236行和239行即get_doc_loss和get_wiki_loss函数,需要注意的是两个函数的输入是不一样的,即batch[0]和batch[1],关于输入数据的格式可以看dataset.py:
主要就是730行,其实就是get_doc_batch和get_wiki_batch两个函数。好了,大概代码逻辑框架知道了,下面分开看:
MLM/RD loss
数据输入就是:get_doc_batch
模型就是:get_doc_loss
如下是get_doc_loss
可以看到,
----------------------------------------------------小分界线 MLM LOSS-------------------------------------------------
MLM loss就是用的huggingface框架的BertForMaskedLM或RobertaForMaskedLM【作者即想实验bert又想实验roberta】,但是由于huggingface框架的MaskedLM返回只有loss,代码中后续还想使用sequence output ,所以作者将transformers源码下载了下来,修改了一下,其实就修改了源码中pretrain\code\pretrain\transformers\modeling_bert.py中的一行即:
增加了一行1014代码,与不加之前的区别在于多返回了个变量即sequence_output,这个后续再算其他两种loss【RD/ED】要用。
同时通过看MLM loss我们可以发现mask_tokens这个函数,这个函数主要就是写了传统的mask策略如下
:
从这里还可以看到,有个not_mask_pos参数,就是传统的mask都是随机mask,这个字段代表一些位置不能被mask,即使随机到了也不行,比如对于本项目中我们在关系判别任务中,头尾实体是不能被mask的,因为我们需要其预测关系,所以not_mask_pos就会记录一些实体位置。
总之我们从该项目的MLM loss这里,我们至少可以学到三点:
(1) 以后我们有什么自己的mask策略想法,想落地实现,其实就是仿效改这个函数。
(2) 只需要修改完上述函数,直接传到对应的huggingface框架下的ForMaskedLM【比如BertForMaskedLM】,就可以直接到返回的Loss,进而进行MLM语言屏蔽模型训练
(3) 遇到一些特殊需求需要改huggingface框架也不是不可以,直接下载transformers代码进行需求修改即可
---------------------------------------------------小分界线 MLM LOSS-------------------------------------------------
以上的MLM loss就是传统的预训练模型,不是本文的创新点,下面我们来看看论文的创新点RD loss也即关系区分任务【接着看上图的get_doc_loss函数,为了方便,这里再放一次】
作者这里用了对比学习:正样本即具有相同远程监督标签的关系表示,负样本与此相反,关于关系的表征,就是其对应的两个实体的简单拼接,即上述代码的173行得到的hidden。
start_re_output和end_re_output可以看做是头实体和尾实体表征。
context_output就是我们上述修改transformers源码返回的sequence ouput
h_mapping和t_mapping是batch传进来的,可以通过get_doc_batch看到就是代表的实体位置,然后通过和context_output相乘就可以滤除全部头实体和尾实体的编码表征
至此用pair_hidden【hidden】和relation_label通过对比学习计算loss【NTXentLoss_doc函数】
对比学习原理这里不在累述,感兴趣的可以看笔者另外一篇博客:
https://blog.csdn.net/weixin_42001089/article/details/117930433
这里对应的公式就是:
该小节的函数get_doc_loss最后返回就是m_loss和r_loss即MLM loss和ED loss也即屏蔽语言模型loss和关系区分loss
MLM/ED loss
数据输入就是:get_wiki_batch
模型就是:get_wiki_loss
首先206行返回的就是mlm loss,前面已经讲过,这里不在累述,一模一样,重点看看ED loss
他的原理是根据头实体和关系预测尾实体
start_re_output可以看出是头实体,而query_re_output可以看做是关系,我们知道paper的关系表征是头尾实体的简单拼接,所以query_re_output是通过query_mapping得到的,可以理解为query_mapping是当前关系对应头尾实体位置,通过和context_output相乘就过滤出对应头尾实体,进而进行拼接得到关系表征,关于query_mapping是batch得到的,可以看get_wiki_batch
该小节的函数get_wiki_loss最后返回就是m_loss和r_loss即MLM loss和RD loss也即屏蔽语言模型loss和实体区分loss
小结
(1) 新加的两个辅助任务是分开进行的【过两次模型】,但二者每次都顺便带了mlm loss
(2) mlm 部分给了很多落地启发,即自己有了个什么想法,能快速使用transfomers实现,尤其二次预训练,甚至我们可以改源码。
finetune
代码在finetune,这里面每一个文件夹代表一个下游任务,没什么可讲的,主要就是用上述pretrain得到的模型去热启就行了。
总结
(1) 以后我们有什么自己的mask策略想法,想落地实现,其实就是仿效改这个函数。
(2) 只需要修改完上述函数,直接传到对应的huggingface框架下的ForMaskedLM【比如BertForMaskedLM】,就可以直接到返回的Loss,进而进行MLM语言屏蔽模型训练
(3) 遇到一些特殊需求需要改huggingface框架也不是不可以,直接下载transformers代码进行需求修改即可
欢迎关注笔者的微信公众号,更多好文章: