本文以项目readme.md训练逻辑的顺序解读
1.下载BERT预训练模型
更多bert模型参考github地址
本文用的是BERT-Base, Cased(12-layer, 768-hidden, 12-heads , 110M parameters)下载地址。其中Cased表示保留真实的大小写和重音标记符,uncased表示文本在单词标记之前就已经变为小写,也去掉了任何重音标记例如,John Smith变成john smith。通常,Uncased模型会更好,除非大小写信息对于我们的任务很重要,如命名实体识别或词性标记。
2.以三元组的形式构造数据集
以NYT数据集为例:NYT数据集是关于远程监督关系抽取任务的广泛使用的数据集。该数据集是通过将freebase中的关系与纽约时报(NYT)语料库对齐而生成的。纽约时报New York Times数据集包含150篇来自纽约时报的商业文章。抓取了从2009年11月到2010年1月纽约时报网站上的所有文章。在句子拆分和标记化之后,使用斯坦福NER标记器来标识PER和ORG从每个句子中的命名实体。对于包含多个标记的命名实体,我们将它们连接成单个标记。然后,我们将同一句子中出现的每一对(PER,ORG)实体作为单个候选关系实例,PER实体被视为ARG-1,ORG实体被视为ARG-2。
2.1运行过程
- NYT数据集下载地址
- 运行CasRel/data/NYT/raw_NYT/generate.py将数字编码形式的nyt数据集转换为字符形式的数据集,并根据三元组将数据集分类为normal,epo,spo几种类型。运行结果为CasRel/data/NYT/new_train.json,new_train_epo.json,new_train_normal.json,new_train_seo.json。将产生的新文件移至NYT/目录下,将new_train.json改名为train.json,test与valid同理。
- 运行CasRel/data/NYT/build_data.py处理数据得到三元组,按train,dev,test分类。
- 将test.json移至test_split_by_num目录中,运行split_by_num.py将test集按每个句子含有的三元组数量分类;将test_epo.json,test_normal.json,test_seo.json移至test_split_by_type目录中,将test集按类型分类得到三元组文件。
2.2重要的代码
generate.py
# 将raw_NYT\train.json中的数字形式生成训练集的文本形式
def load_data(in_file, word_dict, rel_dict, out_file, normal_file, epo_file, seo_file):
with open(in_file, 'r') as f1, open(out_file, 'w') as f2, open(normal_file, 'w') as f3, \
open(epo_file,'w') as f4, open(seo_file, 'w') as f5:
seo_file, 'w') as f5:
cnt_normal = 0
cnt_epo = 0
cnt_seo = 0
lines = f1.readlines() # readlines()方法用于读取所有行(直到结束符EOF)并返回列表
for line in lines:
line = json.loads(line)
print(len(line))
lengths, sents, spos = line[0], line[1], line[2]
print(len(spos))
print(len(sents))
for i in range(len(sents)):
new_line = dict()
# print(sents[i])
# print(spos[i])
tokens = [word_dict[i] for i in sents[i]] # tokens为sents对应数字形式的字符串数组
sent = ' '.join(tokens) # 以空格形式连接字符串数组生成一个新的字符串
new_line['sentText'] = sent # new_line为包含三元组的字典
triples = np.reshape(spos[i], (-1, 3)) # 将spo[i]关系三元组的维度变为3列
relationMentions = []
for triple in triples:
rel = dict()
rel['em1Text'] = tokens[triple[0]]
rel['em2Text'] = tokens[triple[1]]
rel['label'] = rel_dict[triple[2]]
relationMentions.append(rel)
new_line['relationMentions'] = relationMentions
f2.write(json.dumps(new_line) + '\n')
if is_normal_triple(spos[i]):
f3.write(json.dumps(new_line) + '\n')
if is_multi_label(spos[i]):
f4.write(json.dumps(new_line) + '\n')
if is_over_lapping(spos[i]):
f5.write(json.dumps(new_line) + '\n')
build_data.py
# 读取数据集文件,将文本、三元组分类存储
with open('train.json') as f:
for l in tqdm(f): # tqdm是可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器tqdm(iterator)
a = json.loads(l)
if not a['relationMentions']: # 若某个句子a中关系'relationMentions'为空,跳过之
continue
# 提取出每个句子及其三元组
line = {
'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), #去除'sentText'中的\r(回车)、\n(换行)、两头的'\'
'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if
i['label'] != 'None']
}
if not line['triple_list']:
continue
# 将提取出来的句子及其三元组信息加入到训练集数据train_data中,将三元组中的关系加入到集合rel_set中(无序不重复元素序列)
train_data.append(line)
for rm in a['relationMentions']:
if rm['label'] != 'None':
rel_set.add(rm['label'])
3.指定实验设置
run.py中的默认参数:
{
"bert_model": "cased_L-12_H-768_A-12",
"max_len": 100,
"learning_rate": 1e-5,
"batch_size": 6,
"epoch_num": 100,
}
根据自己的设置修改
4.训练模型并评估
确定运行方式,使用的数据集
python run.py ---train=True --dataset=NYT
在测试集上评估
python run.py --dataset=NYT