spo实体关系抽取、属性抽取

7 篇文章 0 订阅

这里面针对的是一个头实体和一个尾实体一一对应的情况下,不存在一个头实体对应多个尾实体交叉的情况。

一个字符,针对同一个关系,只能被标注一次。这个关系再次出现的时候,顺延标注其他字符。
字符标签就是针对每个字符,属于的是某一个关系类别的头实体或者尾实体,或者中间值I或者其他值O,一共2N+2类别的多标签的分类,对于存在多个相同P,的独立的S和O,只要S和O是一一对应的,可以采用这样的标注策略,一个关系标注一个S和O即可,再来同一个关系,如果之前这个关系下的已经标注过了,往后推标注其他字符,这样标注出来才是意义对应的关系。
然后解码的时候,保持原来的顺序一次解码,当一个关系含有多个头实体或者尾实体时,一次性全部取出来,然后一一映射对应即可。

属性抽取用"Text"来定义属性的类别

    {"text": "赵允弼(1007—1069),宋太宗赵炅之孙,镇恭懿王赵元偓之子", 
    "spo_list": [{"predicate": "朝代", "object_type": {"@value": "Text"}, 
    "subject_type": "历史人物", "object": {"@value": "宋"}, "subject": "赵元偓"}]}

训练编码

    1、转换成多标签的分类,每个字符对应有2(N-2)+2个类别,相应的位置置一来标记属于当前类别
    2、label前面的类别索引是头实体的ID也是,类别对应的ID,后面一半是尾实体的ID,对应关系:尾实体ID=头实体ID+N-2
    3、构造原始数据的原始数据,因为分词构建输入词典不一定是一个字符一个输入token,跟分词方式有关
    4、对于存在一个关系多个不同性质的尾实体而增加的关系类别特殊对待

预测解码

    1、找到label为1的地方,找到对应的索引,根据关系ID,头实体ID,尾实体ID得出相应的实体信息
    2、根据相应的label词典,predicate2id.json和id2spo.json得到相应的类别类型信息

spo准确率判断

    1、去除不合规范的spo,去除多余的spo,spo是否相同不是简单的包含判断,去除书名号,判断包含关系
    2、根据统计结果计算相应的,p,r,f

修复问题

1 避免一次性加载所有数据做评估导致内存溢出,改为分批次评估再累加
2 utils文件的write_prediction_results的bug
.zip文件必须先保存上面的文件,再保存zip文件,否则每一次.zip都是上一次的预测结果
2 utils文件的get_precision_recall_f1的bug
f1为1.0的时候匹配出来为0,改成正则直接匹配

上传文件受限

cat model_best.* > model_best_all.zip
unzip model_best_all.zip


def parse_label(spo_list, label_map, tokens, tokenizer):
    # 2 tags for each predicate + I tag + O tag
    num_predict = len(label_map.keys()) - 2
    num_labels = 2 * num_predict + 2
    seq_len = len(tokens)
    # initialize tag
    labels = [[0] * num_labels for i in range(seq_len)]
    #  find all entities and tag them with corresponding "B"/"I" labels
    for spo in spo_list:
        for spo_object in spo['object'].keys():
            # assign relation label
            if spo['predicate'] in label_map.keys():
                # simple relation
                label_subject = label_map[spo['predicate']]  # 关系的头实体,刚好相差关系的数量
                label_object = label_subject + num_predict  # 关系的尾实体,刚好相差关系的数量
                subject_tokens = tokenizer._tokenize(spo['subject'])
                object_tokens = tokenizer._tokenize(spo['object']['@value'])
            else:
                # complex relation
                label_subject = label_map[spo['predicate'] + '_' + spo_object]  # 存在一个头实体(关系)对应多个尾实体的情况
                label_object = label_subject + num_predict
                subject_tokens = tokenizer._tokenize(spo['subject'])
                object_tokens = tokenizer._tokenize(spo['object'][spo_object])

            subject_tokens_len = len(subject_tokens)
            object_tokens_len = len(object_tokens)

            # assign token label
            # there are situations where s entity and o entity might overlap, e.g. xyz established xyz corporation
            # to prevent single token from being labeled into two different entity
            # we tag the longer entity first, then match the shorter entity within the rest text
            forbidden_index = None
            if subject_tokens_len > object_tokens_len:
                for index in range(seq_len - subject_tokens_len + 1):
                    if tokens[index:index + subject_tokens_len] == subject_tokens and labels[index][label_subject] == 0:
                        labels[index][label_subject] = 1
                        for i in range(subject_tokens_len - 1):
                            labels[index + i + 1][1] = 1
                        forbidden_index = index
                        break

                for index in range(seq_len - object_tokens_len + 1):
                    if tokens[index:index + object_tokens_len] == object_tokens and labels[index][label_object] == 0:
                        if forbidden_index is None:
                            labels[index][label_object] = 1
                            for i in range(object_tokens_len - 1):
                                labels[index + i + 1][1] = 1
                            break
                        # check if labeled already
                        elif index < forbidden_index or index >= forbidden_index + len(
                                subject_tokens):
                            labels[index][label_object] = 1
                            for i in range(object_tokens_len - 1):
                                labels[index + i + 1][1] = 1
                            break

            else:
                for index in range(seq_len - object_tokens_len + 1):
                    if tokens[index:index + object_tokens_len] == object_tokens and labels[index][label_object] == 0:
                        labels[index][label_object] = 1
                        for i in range(object_tokens_len - 1):
                            labels[index + i + 1][1] = 1
                        forbidden_index = index
                        break

                for index in range(seq_len - subject_tokens_len + 1):
                    if tokens[index:index + subject_tokens_len] == subject_tokens and labels[index][label_subject] == 0:
                        if forbidden_index is None:
                            labels[index][label_subject] = 1
                            for i in range(subject_tokens_len - 1):
                                labels[index + i + 1][1] = 1
                            break
                        elif index < forbidden_index or index >= forbidden_index + len(
                                object_tokens):
                            labels[index][label_subject] = 1
                            for i in range(subject_tokens_len - 1):
                                labels[index + i + 1][1] = 1
                            break


def decoding(example_batch,
             id2spo,
             logits_batch,
             seq_len_batch,
             tok_to_orig_start_index_batch,
             tok_to_orig_end_index_batch):
    """
    model output logits -> formatted spo (as in data set file)
    """
    num_predict = len(id2spo['predicate']) - 2
    formatted_outputs = []
    for (i, (example, logits, seq_len, tok_to_orig_start_index, tok_to_orig_end_index)) in \
            enumerate(zip(example_batch, logits_batch, seq_len_batch, tok_to_orig_start_index_batch,
                          tok_to_orig_end_index_batch)):

        logits = logits[1:seq_len + 1]  # slice between [CLS] and [SEP] to get valid logits
        logits[logits >= 0.5] = 1
        logits[logits < 0.5] = 0
        tok_to_orig_start_index = tok_to_orig_start_index[1:seq_len + 1]
        tok_to_orig_end_index = tok_to_orig_end_index[1:seq_len + 1]
        predictions = []
        for token in logits:
            predictions.append(np.argwhere(token == 1).tolist())  # 返回对应的索引值

        # format predictions into example-style output
        formatted_instance = {}
        text_raw = example['text']

        # flatten predictions then retrival all valid subject id
        flatten_predictions = []
        for layer_1 in predictions:
            for layer_2 in layer_1:
                flatten_predictions.append(layer_2[0])
        subject_id_list = []
        flatten_predictions_set = list(set(flatten_predictions))
        flatten_predictions_set.sort(key=flatten_predictions.index)
        for cls_label in flatten_predictions_set:  # 这个多标签的分类类别,前面的55个类别是头实体B,后面的是尾实体B
            if 1 < cls_label <= num_predict + 1 and (cls_label + num_predict) in flatten_predictions_set:
                subject_id_list.append(cls_label)  # 类别头实体的id也刚好和关系类别的id重合
        subject_id_list_set = list(set(subject_id_list))
        subject_id_list_set.sort(key=subject_id_list.index)
        # fetch all valid spo by subject id
        spo_list = []
        for id_ in subject_id_list_set:
            if id_ in complex_relation_affi_label:
                continue  # do this in the next "else" branch
            if id_ not in complex_relation_label:
                subjects = find_entity(text_raw, id_, predictions,
                                       tok_to_orig_start_index,
                                       tok_to_orig_end_index)
                objects = find_entity(text_raw, id_ + num_predict, predictions,
                                      tok_to_orig_start_index,
                                      tok_to_orig_end_index)
                for subject_, object_ in zip(subjects, objects):
                    spo_list.append({
                        "predicate": id2spo['predicate'][id_],
                        "object_type": {
                            '@value': id2spo['object_type'][id_]
                        },
                        'subject_type': id2spo['subject_type'][id_],
                        "object": {
                            '@value': object_
                        },
                        "subject": subject_
                    })
            else:
                #  traverse all complex relation and look through their corresponding affiliated objects
                subjects = find_entity(text_raw, id_, predictions,
                                       tok_to_orig_start_index,
                                       tok_to_orig_end_index)
                objects = find_entity(text_raw, id_ + num_predict, predictions,
                                      tok_to_orig_start_index,
                                      tok_to_orig_end_index)
                for subject_ in subjects:
                    for object_ in objects:
                        object_dict = {'@value': object_}
                        object_type_dict = {
                            '@value': id2spo['object_type'][id_].split('_')[0]
                        }
                        if id_ in two_complex_relation and id_ + 1 in subject_id_list_set:
                            id_affi = id_ + 1  # 连续一个类别对应两个个尾实体的一类,只要出现一个就把暗含的其他的一起找出来
                            object_dict[id2spo['object_type'][id_affi].split(
                                '_')[1]] = find_entity(text_raw, id_affi + num_predict,
                                                       predictions,
                                                       tok_to_orig_start_index,
                                                       tok_to_orig_end_index)[0]
                            object_type_dict[id2spo['object_type'][
                                id_affi].split('_')[1]] = id2spo['object_type'][
                                    id_affi].split('_')[0]
                        elif id_ == four_complex_relation[0]:  # 连续一个类别对应四个个尾实体的一类
                            for id_affi in four_complex_relation[1:]:
                                if id_affi in subject_id_list_set:
                                    object_dict[id2spo['object_type'][id_affi].split('_')[1]] = \
                                        find_entity(text_raw, id_affi + num_predict, predictions,
                                                    tok_to_orig_start_index, tok_to_orig_end_index)[0]
                                    object_type_dict[id2spo['object_type'][id_affi].split('_')[1]] = \
                                        id2spo['object_type'][id_affi].split('_')[0]
                        spo_list.append({
                            "predicate": id2spo['predicate'][id_],
                            "object_type": object_type_dict,
                            "subject_type": id2spo['subject_type'][id_],
                            "object": object_dict,
                            "subject": subject_
                        })

        formatted_instance['text'] = example['text']
        formatted_instance['spo_list'] = spo_list
        formatted_outputs.append(formatted_instance)
    return formatted_outputs
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值