信息抽取(二)花了一个星期走了无数条弯路终于用TF复现了苏神的《Bert三元关系抽取》模型,我到底悟到了什么?
前言
先上热菜致敬苏神:苏剑林. (2020, Jan 03). 《用bert4keras做三元组抽取 》[Blog post]. Retrieved from https://kexue.fm/archives/7161
建议大家先看苏神的原文,如果您能看懂思路和代码的话我的文章可能对你的帮助不大。
拜读这篇文章之后本人用TF + Transformers 复现了该baseline模型,并在其基础上进行了大量的尝试,直到心累也没有成功复现相同水平的结果,但也有所接近,因此用这篇文章复盘整个过程并分享一些收获和心得。
数据格式与任务目标
数据下载地址:https://ai.baidu.com/broad/download?dataset=sked
数据格式:
{
"text": "查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部",
"spo_list":
[{
"predicate": "出生地", "object_type": "地点", "subject_type": "人物", "object": "圣地亚哥", "subject": "查尔斯·阿兰基斯"},
{
"predicate": "出生日期", "object_type": "Date", "subject_type": "人物", "object": "1989年4月17日", "subject": "查尔斯·阿兰基斯"}]}
简单来说给定一段文本,我们需要从中抽取出多组 S(subject) P(predicate) O(object_type)的关系。
例如:“查尔斯·阿兰基斯–出生日期–1989年4月17日”则是一组我们需要抽取出来的信息。而 P(需要预测的关系)已经给定范围,一共49类关系,具体见 all_50_schemas 。
模型整体思路
这个模型思路的精彩之处:
-
该任务本来应该分成两个模块完成:1.抽取实体(包括S和O)2.判断实体之间的关系,理应至少需要两个模型协同完成,但苏神将实体之间的关系类别预测隐性的放在了O抽取的过程中,即让模型在预测O的时候直接预测O与S的关系P。
-
指针标注:对每个span的start和end进行标记,对于多片段抽取问题转化为N个2分类(N为序列长度),如果涉及多类别可以转化为层叠式指针标注(C个指针网络,C为类别总数)。事实上,指针标注已经成为统一实体、关系、事件抽取的一个“大杀器”。
-
由于一个文本中可能存在多对SPO关系组,甚至可能存在S之间有Overlap,O之间有Overlap的情况,因此模型的输出层使用的是半指针-半标注的sigmoid(类似多标签预测实体的始末位置,与阅读理解相似)这样可以让模型同时标注多对S和O。
-
使用Conditional Layer Normalization 我们需要在预测PO时告诉模型,我们的S是什么,以至于使得模型学习到PO的预测是依赖于S的,而不是看见“日期”就认为是出生年月。具体的内部实现流程也可以参考我的代码,会有介绍。(这各地方也卡了我很久才跑通)最后评估下来这个方法有利也有弊。
复现代码
数据处理
数据读取
def load_data(path):
text_list = []
spo_list = []
with open(path) as json_file:
for i in json_file:
text_list.append(eval(i)['text'])
spo_list.append(eval(i)['spo_list'])
return text_list,spo_list
def load_ps(path):
with open(path,'r') as f:
data = pd.DataFrame([eval(i) for i in f])['predicate']
p2id = {
}
id2p = {
}
data = list(set(data))
for i in range(len(data)):
p2id[data[i]] = i
id2p[i] = data[i]
return p2id,id2p
训练数据处理
这里处理的思路和信息抽取(一)中处理的思路相似,有详细的代码注释:
信息抽取(一)机器阅读理解——样本数据处理与Baseline模型搭建训练(2020语言与智能技术竞赛)
这里主要介绍针对本次任务的几个细节和trick:
- 由于一段文本可能存在多个S,因此遍历一组数据里的所有SPO关系,将所有S的头尾位置放在一个01数组中。
- 对于存在多组SPO关系的样本,在标注PO时,我们只随机选取一个S,理由比较简单,你没办法一下子传入多个S给下一个模型。
- 抽取S时,随机选取一个S的首位置,从所有S的末位置中选取一个与之匹配,如果是完整的S,则对其所有的PO进行标注,否则跳过,该样本作为负样本。这是为了让模型学会并非所有抽取出来的S都有对应的PO关系。
- 由于限制了token长度,对于找不到S的样本最后去除,对于找不到P的样本保留,同样作为负样本。
def proceed_data(text_list,spo_list,p2id,id2p,tokenizer,MAX_LEN):
id_label = {
}
ct = len(text_list)
MAX_LEN = MAX_LEN
input_ids = np.zeros((ct,MAX_LEN),dtype='int32')
attention_mask = np.zeros((ct,MAX_LEN),dtype='int32')
start_tokens = np.zeros((ct,MAX_LEN),dtype='int32')
end_tokens = np.zeros((ct,MAX_LEN),dtype='int32')
send_s_po = np.zeros((ct,2),dtype='int32')
object_start_tokens = np.zeros((ct,MAX_LEN,len(p2id)),dtype='int32')
object_end_tokens = np.zeros((ct,MAX_LEN,len(p2id)),dtype='int32')
invalid_index = []
for k in range(ct):
context_k = text_list[k].lower().replace(' ','')
enc_context = tokenizer.encode(context_k,max_length=MAX_LEN,truncation=True)
if len(spo_list[k])==0:
invalid_index.append(k)
continue
start = []
end = []
S_index = []
for j in range(len(spo_list[k])):
answers_text_k = spo_list[k][j]['subject'].lower().replace(' ','')
chars = np.zeros((len(context_k)))
index = context_k.find(answers_text_k)
chars[index:index+len(answers_text_k)]=1
offsets = []
idx=0
for t in enc_context[1:]:
w = tokenizer.decode([t])
if '#' in w and len(w)>1:
w = w.replace('#','')
if w == '[UNK]':
w = '。'