NPL工具——NER任务的多模型投票器

0. 介绍

最近在做命名实体识别(NER)相关的任务,在做一个集成的模型,涉及到多个模型结果融合的问题,需要用某种方法把多个模型预测出来的结果进行投票,得出最终的结果。由于任务是flat的NER,所以在投票的过程中需要避免实体重叠的问题。

为了实现这个功能我写了一个投票器类,把它记录下来,方便以后需要的时候再次使用。

1.数据格式

假设所有k个模型预测出来的结果保存为list格式的result,result的长度即为k,每一个元素对应一个dict,记录模型的预测结果,dict的键为类别名称,值为所有检测为该类的实体。

result = [{'类别1': [],
  '类别2': [],
  '类别3': [[25, 31]],
  '类别4': [[118, 123]],
  '类别5': [[70, 71], [94, 99]],
  '类别6': []},
 {'类别1': [[182, 183]],
  '类别2': [],
  '类别3': [[25, 31], [44, 52], [79, 92]],
  '类别4': [[118, 123]],
  '类别5': [[70, 71], [94, 99]],
  '类别6': []},
  ……
 {'类别1': [],
  '类别2': [],
  '类别3': [[25, 31], [44, 52]],
  '类别4': [[118, 123]],
  '类别5': [[44, 52], [70, 71], [96, 99]],
  '类别6': []}]

2.投票规则

首先回顾一下一般的分类任务中,bagging的策略是如何进行的,最简单的规则就是少数服从多数的规则,例如10个模型中,如果有8个将它分为A类,两个分为B类,那么最终结果就判定为A类,但是在NER任务中,由于涉及到实体的区间(span),便没有办法只采用简单的投票法将实体标出,因为可能某一个位置附近确定出现有一个实体,但是还需要判断①这个位置的实体的起始位置,②这个位置的实体所属的类别。

例如,某句话中,模型1将“粉色海星派大星”识别为人物类,模型2将“海星派大星”识别为人物类,模型3将“粉色海星”识别为人物类,那最终投票的结果又该如何判定呢?

于是我设计了一种投票的规则,规则或许仍然存在不合理的地方,但可以输出一个逻辑完整的,较为可靠的结果。

规则&流程
1.生成初始化:读取所有模型的结果results,遍历其中识别到的每一个实体(不论类型),将所有的开始和结束位置记录下来,生成一个初始化的计数‘字典’,计数‘字典’的键为这个位置,值为这个位置作为开始或者结束位置出现的次数。由于在python中dict对象在迭代中是不可变的,所以用一个list来模拟这个‘字典’,list的index模拟‘字典’的键,然后建立一个从index到位置的映射就可以了。
2.统计出现次数:再次读取results,对初始化计数‘字典’中出现的所有位置,记录这个位置在所有模型中作为所有类型的起始和结束位置出现过的次数(后来这个次数改成了加权,权重为每个模型的f1的值),填到‘字典’的值上,至此‘字典’的每个位置上对应的都是一个p*2的array,p是实体类别的数量。
3.寻找第一显著位:在上面生成的计数‘字典’中,寻找第一显著位置,如果大于‘显著阈值’就去匹配与它相对应的开始或结束位置。如果第一显著位置是start位,则向右去寻找这个实体的end位;如果是end位,则向左去寻找这个实体的start位。找到第一显著位置之后,将计数‘字典’的这个位置的数值置为0.
4.匹配第一显著位:以向右寻找end位为例,说明匹配规则。这个匹配位置应当满足:(1)生成的span不能与已有的span重叠;(2)匹配位置应当是所有该类型(与3中找到的第一显著位同类)中,最显著的位置;(3)匹配位置的计数值满足‘显著阈值’。匹配成功后,将匹配位置在计数‘字典’中的计数值置为0,并将新生成的实体span添加到已有span中去。
5.循环:继续执行3和4两步,在剩下的位置中寻找第一显著位并匹配出实体,直到第一显著位的显著程度小于设定的显著阈值,则跳出循环。

3.代码实现

import numpy as np
import copy

class Voter():
    def __init__(self, threshold, results):
        self.threshold = threshold   #  显著阈值
        self.results = results         #  所有模型的结果
        self.spans = []              #  现有实体的所有span
    
    
    def predicate2id(self, predicate):
        pr2id = {'类别1':0, '类别2':1, '类别3':2, '类别4':3}
        return pr2id[predicate]
    
    
    def id2predicate(self, id):
        id2pr = {0:'类别1', 1:'类别2', 2:'类别3', 3:'类别4'}
        return id2pr[id]
    
    
    def model_point(self, model_id):
    	'''
    	这里记录的是所有模型的f1的值,作为权重,注意修改
    	'''
        point = [0.6153846153847338, 0.6177606177607161, 0.6169014084508121, 0.5877318116976925, 0.573333333333447,
                0.6627043090639932, 0.630225080385971, 0.6635514018692636, 0.6210720887247242]
        return point[model_id]
    
    def sub_of(self, sub_inter, inter):
        '''
        辅助工具:判断一个区间是不是另一个区间的子区间
        '''
        a1, a2 = sub_inter[0], sub_inter[1]
        # print(a1)
        # print(a2)
        if a1 > a2:
            return False
        if len(inter):
            b1, b2 = inter[0], inter[1]
            assert b1 < b2
        else:
            b1, b2 = 0, 0
        if a1 >= b1 and a2 <= b2:
            return True
        else:
            return False
        
        
    def find_all_spans_by_cls(self, cls):
        '''
        辅助工具:获取所有模型中某类别所有实体对应区间
        '''
        all_spans_by_cls = []
        results = self.results
        for result in results:   # 对每一个模型的结果
            for span in result[self.id2predicate(cls)]:   # 对当前模型结果中这一类的所有span
                if span not in all_spans_by_cls:   # 如果不在已经选出来的span中
                    all_spans_by_cls.append(span)
        return all_spans_by_cls
    
    
    def generate_init(self):
        '''
        生成初始化字典
        由于字典在迭代过程中不能改变其中数值
        所以将计数的存储方式改为list
        并建立一个从position到list的index的映射,模拟字典的key
        '''
        count_dict = []
        key2index = {}   # 这两个映射一旦生成了就不用在动它了
        index2key = {}
        i = 0
        for model_res in self.results:
            for key in model_res:   # 对每一类
                # print(model_res[key])   # 每一类对应的实体
                for v in model_res[key]:    # 每一类对应的每一个实体
                    # print(v)
                    for vv in v:             # 每一类对应的每一个实体对应的start和end
                        # print(vv)
                        # print(count_dict)
                        if str(vv) not in key2index.keys():
                            key2index[str(vv)] = i
                            index2key[i] = str(vv)
                            count_dict.append(np.zeros((4,2)))
                            i += 1
        return count_dict, key2index, index2key
        

    def fill_count(self):
        '''
        每个位置计数
        '''
        count_dict, key2index, index2key = self.generate_init()
        for model_id, model_res in enumerate(self.results):
            for key in model_res:
                for v in model_res[key]:  # v 是每一个实体对应的start和end的list
                    if key == '试验要素':
                        count_dict[key2index[str(v[0])]][0][0] += self.model_point(model_id)  # v的start位置的第一行第一列  代表试验要素的开始
                        count_dict[key2index[str(v[1])]][0][1] += self.model_point(model_id)  # v的end位置的第一行第二列  代表试验要素的结束
                    elif key == '性能指标':
                        count_dict[key2index[str(v[0])]][1][0] += self.model_point(model_id)
                        count_dict[key2index[str(v[1])]][1][1] += self.model_point(model_id)
                    elif key == '任务场景':
                        count_dict[key2index[str(v[0])]][2][0] += self.model_point(model_id)
                        count_dict[key2index[str(v[1])]][2][1] += self.model_point(model_id)
                    elif key == '系统组成':
                        count_dict[key2index[str(v[0])]][3][0] += self.model_point(model_id)
                        count_dict[key2index[str(v[1])]][3][1] += self.model_point(model_id)
        return count_dict
    
    
    def search_first(self, count_dict, key2index, index2key):
        '''
        寻找count_dict中出现次数最多的位置
        返回其是start还是end,其分类码,以及其对应数值
        并在count_dict中将这个位置置为0
        '''
        print('searching first...')
        
        max_pos = 0   # 当前最大计数对应位置
        max_count = 0  # 当前最大计数

        for i in range(len(count_dict)):
            pos = index2key[i]
            cur_count = np.max(count_dict[i])
            if cur_count > max_count:
                mx = np.where(count_dict[i] == cur_count)
                cls = int(mx[0])        # 对应类别编号
                se = int(mx[1])         # 对应开始结束
                max_pos = pos
                max_count = cur_count
                
        print('got max_pos: %s' %max_pos)
        print('current max_count is %s' % max_count)
        # print('remove pos: %s' %max_pos)
        count_dict[key2index[max_pos]] = np.zeros((4,2))  # 这个位置置为0
        return se, cls, int(max_pos), count_dict, max_count

    
    def search_backward(self, cls, base_pos, count_dict, spans, key2index, index2key):
        '''
        当search_first函数搜索到的是se为1(end),则向后找start
        cls:search_first搜索到的cls
        base_pos:基准位置
        返回:搜索到的最匹配位置
        '''
        print('----------')
        print('searching backward...')
        max_pos = -1
        max_count = 0
        base_pos = int(base_pos)
        print('match for pos: %s' %base_pos)
        # print(spans)
        span_to_append = []
        
        for i in range(len(count_dict)):
            '''
            规则:
            1.所选点在base之前
            2.所选点在潜在点集中(已满足)
            3.所选点与base之间所有点都在至少一个模型的实体结果中
            4.所选点在上一个同类span的end之后(当前span不是第一个时,才判断规则4)
            '''
            pos = index2key[i]
            
            # tmp_span用于判断base在已有span中的位置
            tmp_span = copy.copy(spans)
            if [base_pos, base_pos] not in tmp_span:
                tmp_span.append([base_pos, base_pos])
            # print([base_pos,base_pos])
            # print(tmp_span)
            tmp_span.sort()
            
            # 开始对规则3进行判断
            all_spans_by_cls = self.find_all_spans_by_cls(cls)
            prncp3 = False
            for span in all_spans_by_cls:     # 对每一个同类实体,判断所选区间是不是其子集
                prncp3 = prncp3 or self.sub_of([int(pos), base_pos], span)
             
            if len(spans):       # 如果spans这个时候已经是非空的
                # print('base_pos 在tmp_span中前边紧接着的span:%s' %(tmp_span[tmp_span.index([base_pos, base_pos])-1]))
                
                if tmp_span.index([base_pos, base_pos]) == 0:
                    # 如果base_pos在tmp_span中已经是第一个,前面没有了,那么就可以往前随便选
                    if int(pos) < base_pos and prncp3:
                        cur_count = count_dict[i][cls][0]
                        if cur_count > max_count:
                            max_count = cur_count
                            max_pos = int(pos)
                elif tmp_span.index([base_pos, base_pos]) > 0:
                    # 如果base在tmp中不是第一个,前面还有,那么需要保证找的匹配点在前面一个span之后(prncp4)
                    prncp4 = tmp_span[tmp_span.index([base_pos, base_pos])-1][1] < int(pos)
                    if int(pos) < base_pos and prncp3 and prncp4:   # 向前搜索,并且不在已有的span中
                        cur_count = count_dict[i][cls][0]
                        if cur_count > max_count:
                            max_count = cur_count
                            max_pos = int(pos)
            else:                                   # 初始情况下spans为空,不需要判断在不在已有的span中
                if int(pos) < base_pos and prncp3:
                    cur_count = count_dict[i][cls][0]
                    if cur_count > max_count:
                        max_count = cur_count
                        max_pos = int(pos)
                        # print(max_pos)
        if max_pos >= 0:
            print('got max_pos at %s' % max_pos)
            count_dict[key2index[str(max_pos)]] = np.zeros((4,2))   # 置为0
            # print('remove pos: %s' % max_pos)
            span_to_append = [max_pos, base_pos]   # 准备追加的span
            # print(span_to_append)
                    
        if span_to_append not in spans and len(span_to_append):
            print('doing backward append...')
            if len(spans):
                spans.sort()
                for span in spans:
                    if span[0] == span_to_append[1]+1 and span != span_to_append:      # 跟下一个span连起来了
                        span_to_append = [span_to_append[0], span[1]]   # 取首尾,中间不要
                        spans.append(span_to_append)
                        spans.remove(span)                             # 原来的删掉
                    elif span[1] == span_to_append[0]-1 and span != span_to_append:     # 跟上一个span连起来了
                        span_to_append = [span[0], span_to_append[1]]    # 取首尾
                        spans.append(span_to_append)
                        spans.remove(span)
                    else:
                        if span != span_to_append:
                            spans.append(span_to_append)                  # 没有接起来的情况,直接append
            elif len(spans) == 0:
                spans.append(span_to_append)
        # print('spans after searched backward: %s' % spans)
        return int(max_pos), count_dict, spans

    
    def search_forward(self, cls, base_pos, count_dict, spans, key2index, index2key):
        '''
        当search_first函数搜索到的是se为0(start),则向前找end
        cls:search_first搜索到的cls
        base_pos:基准位置
        返回:搜索到的最匹配位置
        '''
        print('----------')
        print('searching forward...')
        max_pos = -1
        max_count = 0
        base_pos = int(base_pos)
        # print(spans)
        print('match for pos: %s' %base_pos)
        span_to_append = []
        
        for i in range(len(count_dict)):
            '''
            规则:
            1.所选点在base之后
            2.所选点在潜在点集中(已满足)
            3.所选点与base之间所有点都在至少一个模型的实体结果中
            4.所选点在下一个同类span的start之前(当前span不是最后一个时,才判断规则4)
            '''
            pos = index2key[i]  # 找出所有潜在的pos,str类型,并对每一个pos进行循环
            
            tmp_span = copy.copy(spans)      # 复制一个spans,并把当前位置加进去,以寻找其相邻的span
            if [base_pos, base_pos] not in tmp_span:
                tmp_span.append([base_pos, base_pos])
            # print(spans)
            # print([base_pos,base_pos])
            # print(tmp_span)
            tmp_span.sort()
            
            # 开始对规则3进行判断
            all_spans_by_cls = self.find_all_spans_by_cls(cls)
            prncp3 = False
            for span in all_spans_by_cls:     # 对每一个同类实体,判断所选区间是不是其子集
                prncp3 = prncp3 or self.sub_of([base_pos, int(pos)], span)
            
            if len(spans):       # 如果spans这个时候已经是非空的
                # print(spans)
                # print('tmp_span:%s' %tmp_span)
                # print(tmp_span.index([base_pos, base_pos]))
                # print(len(tmp_span))
                if tmp_span.index([base_pos, base_pos])+1 == len(tmp_span):   
                    # base_pos是tmp_span中的最后一个,后边没有了,那么后面的所有点都可选
                    # print('后面没有了')
                    if int(pos) > base_pos and prncp3:
                        cur_count = count_dict[i][cls][1]
                        if cur_count > max_count:
                            max_count = cur_count
                            max_pos = int(pos)
                elif tmp_span.index([base_pos, base_pos])+1 < len(tmp_span):
                    # 如果base_pos后面还有别的实体,那么只能选到这个实体之前
                    # print('base_pos 在tmp_span中后边紧接着的span:%s' %(tmp_span[tmp_span.index([base_pos, base_pos])+1]))
                    prncp4 = tmp_span[tmp_span.index([base_pos, base_pos])+1][0] > int(pos)
                    if int(pos) > base_pos and prncp3 and prncp4:   # 向前搜索,并且不在已有的span中
                        cur_count = count_dict[i][cls][1]
                        if cur_count > max_count:
                            max_count = cur_count
                            max_pos = int(pos)
            else:                                   # 初始情况下spans为空,不需要判断在不在已有的span中
                if int(pos) > base_pos and prncp3:
                    cur_count = count_dict[i][cls][1]
                    if cur_count > max_count:
                        max_count = cur_count
                        max_pos = int(pos)
                        # print(max_pos)
        if max_pos >= 0:
            print('got max_pos at %s' % max_pos)
            count_dict[key2index[str(max_pos)]] = np.zeros((4,2))
            # print('remove pos: %s' % max_pos)
            span_to_append = [base_pos, max_pos]
            # print(span_to_append)
                    
        if span_to_append not in spans and len(span_to_append): # 如果准备追加的不在原有spans中
            if len(spans):   # 如果spans已有内容
                print('doing backward append...')
                spans.sort()
                for span in spans:
                    if span[0] == span_to_append[1]+1 and span != span_to_append:     # 跟下一个span连起来了
                        span_to_append = [span_to_append[0], span[1]]   # 取首尾,中间不要
                        spans.append(span_to_append)
                        spans.remove(span)                             # 原来的删掉
                    elif span[1] == span_to_append[0]-1 and span != span_to_append:     # 跟上一个span连起来了
                        span_to_append = [span[0], span_to_append[1]]    # 取首尾
                        spans.append(span_to_append)
                        spans.remove(span)
                    else:
                        if span != span_to_append:
                            spans.append(span_to_append)                  # 没有接起来的情况,直接append
            elif len(spans) == 0:      # 如果现在spans还没有内容,但是有内容可以加入
                spans.append(span_to_append)
        # print('spans after searched forward: %s' % spans)
        return int(max_pos), count_dict, spans
    
    
    def generate_res(self):
        '''
        生成最终的结果
        '''
        res = {'类别1':[], '类别2':[], '类别3':[], '类别4':[]}
        spans = self.spans
        threshold = self.threshold
        print('=======================')
        print('set threshold: %s' % threshold)
        print('=======================')
        _, key2index, index2key = self.generate_init()  # 只是为了保存两个dict
        count_dict = self.fill_count()  # 初始化
        
        while True:                   # 满足阈值条件时,一直执行,不满足时,跳出
            # cur_se, cur_cls, cur_pos, self.count_dict, max_count = self.search_first(count_dict, key2index, index2key)
            try:
                cur_se, cur_cls, cur_pos, self.count_dict, max_count = self.search_first(count_dict, key2index, index2key)
            except Exception as e:
                print(e)
                break
            if max_count < threshold:
                break
            if cur_se == 0:    # 如果找到的是一个start,接下来就找它对应的end
                cur_end, count_dict, spans = self.search_forward(cls=cur_cls, base_pos=cur_pos, count_dict=count_dict, spans=spans, key2index=key2index, index2key=index2key)
                if cur_end != -1:
                    res[self.id2predicate(cur_cls)].append([cur_pos, cur_end])       # 保存结果,最终保存的不是spans而是res
            elif cur_se == 1:    # 如果找到的是一个end,接下来就找它对应的start
                cur_start, count_dict, spans = self.search_backward(cls=cur_cls, base_pos=cur_pos, count_dict=count_dict, spans=spans, key2index=key2index, index2key=index2key)
                if cur_start != -1:
                    res[self.id2predicate(cur_cls)].append([cur_start, cur_pos])
            
        return res
            

3.使用方法

首先注意修改model_point函数中对应的f1的分数,然后注意类别数量和名称要与自己的数据集对应。
还有np.zeros生成的array的维度要与自己的类别数量对应上。

V = Voter(threshold, results)
final_res = V.generate_res()

4.其它情况

这种投票规则会出现一种情况没有办法解决,就是search_first寻找第一显著位的时候,如果两个位置具有相同的显著计数,则代码无法继续进行,当遇到这种情况我是单独用f1值最大的结果作为最终结果的。

这篇博客主要是写给我自己看的,如果你有其他的更好的投票方法,或者认为我的方法有明显的BUG,欢迎留言。如果这篇文章对你有帮助,麻烦点个赞吧。

  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值