import os
import sys
import json
def read_by_lines(path):
result = list()
with open(path, "r", encoding="utf8") as infile:
for line in infile:
result.append(line.strip())
return result
def write_by_lines(path, data):
with open(path, "w", encoding="utf8") as outfile:
[outfile.write(d + "\n") for d in data]
def extract_result(text, labels):
ret, is_start, cur_type = [], False, None
if len(text) != len(labels):
labels = labels[:len(text)]#如果labels长,就截断
for i, label in enumerate(labels):
if label != u"O":
_type = label[2:]#切取当前标签表示事件类型的部分
if label.startswith(u"B-"):#如果标签以B-开头
is_start = True#设置标记
cur_type = _type# cur_type是外面的变量
#ret是保存触发词或者论元位置,内容的集合
ret.append({"start": i, "text": [text[i]], "type": _type})
elif _type != cur_type:#如果I-后边跟的和之前的一样,是进不来这里的
cur_type = _type
is_start = True
ret.append({"start": i, "text": [text[i]], "type": _type})
elif is_start:#标记未变,表示是同一个论元,-1取的当前的论元,追加的真实文本字符
ret[-1]["text"].append(text[i])
else:
cur_type = None
is_start = False
else:
cur_type = None#如果是非实体,当前类型None,is_start:False
is_start = False
return ret
trigger_file='./ckpt/DuEE1.0/trigger/test_pred.json'
role_file='./ckpt/DuEE1.0/role/test_pred.json'
schema_file='./datasets/DuEE_1_0/event_schema.json'
save_path='./submit/test_duee_1.json'
pred_ret = []#保存预测结果
trigger_data = read_by_lines(trigger_file)#读取句子级触发词预测
role_data = read_by_lines(role_file)#读取句子论元角色预测
schema_data = read_by_lines(schema_file)#schema,规则
print("trigger predict {} load from {}".format(len(trigger_data),
trigger_file))
print("role predict {} load from {}".format(len(role_data), role_file))
print("schema {} load from {}".format(len(schema_data), schema_file))
schema = {}#特定的事件类型只能对应特定的论元角色
for s in schema_data:
d_json = json.loads(s)#一个json字典对象
#建立事件类型和角色列表之间的对应关系
schema[d_json["event_type"]] = [r["role"] for r in d_json["role_list"]]
# process the role data
sent_role_mapping = {}
for d in role_data:
d_json = json.loads(d)
r_ret = extract_result(d_json["text"], d_json["pred"]["labels"])
role_ret = {}#论元
for r in r_ret:
role_type = r["type"]#角色类型
if role_type not in role_ret.keys():#没有就新建立映射
role_ret[role_type] = []
role_ret[role_type].append("".join(r["text"]))#把提取的标签文本加进去
sent_role_mapping[d_json["id"]] = role_ret#句子id到角色论元对的映射
sent_role_mapping['20abb32653a37fc1d7c75e334fe6a924'].items()
for d in trigger_data:
d_json = json.loads(d)
t_ret = extract_result(d_json["text"], d_json["pred"]["labels"])#提取触发词位置,文本
#遍历当前样本的所有事件类型
pred_event_types = list(set([t["type"] for t in t_ret]))
event_list = []
for event_type in pred_event_types:#遍历所有事件类型
role_list = schema[event_type]#获取事件类型对应的论元角色列表
arguments = []
#遍历当前句子id对应的论元角色类型和论元
for role_type, ags in sent_role_mapping[d_json["id"]].items():
if role_type not in role_list:#过滤掉不在角色列表里的,因为必须和事件对应
continue
for arg in ags:
if len(arg) == 1:#过滤掉只有一个论元字符串的
continue
#把角色论元对加进列表
arguments.append({"role": role_type, "argument": arg})
#当前事件包括事件类型 和里面的论元对
event = {"event_type": event_type, "arguments": arguments}
event_list.append(event)#这还是 一个样本对应的事件列表
pred_ret.append({#
"id": d_json["id"],
"text": d_json["text"],
"event_list": event_list
})
pred_ret = [json.dumps(r, ensure_ascii=False) for r in pred_ret]
print("submit data {} save to {}".format(len(pred_ret), save_path))
write_by_lines(save_path, pred_ret)