MRC Framework for Named Entity Recognition【代码解读】

75 篇文章 7 订阅
61 篇文章 2 订阅

一、预备知识

  • 介绍文章
    • A Unified MRC Framework for Named Entity Recognition【文章学习】:https://blog.csdn.net/qq_16949707/article/details/115517783?spm=1001.2014.3001.5501
    • [NLP]MRC is All you Need?
    • https://zhpmatrix.github.io/2020/05/07/mrc-is-all-you-need/
    • 中文NER任务实验小结报告——深入模型实现细节
    • https://zhuanlan.zhihu.com/p/103779616
    • A Unified MRC Framework for Named Entity Recognition
    • https://www.aclweb.org/anthology/2020.acl-main.519/
  • 代码:
    • 原文:https://github.com/ShannonAI/mrc-for-flat-nested-ner
    • 其他实现:https://github.com/qiufengyuyi/sequence_tagging
  • 相关知识
    • 损失函数|交叉熵损失函数
    • https://zhuanlan.zhihu.com/p/35709485
    • softmax及其交叉熵求导(头条面试) - 面试篇
    • https://blog.csdn.net/GreatXiang888/article/details/99293507
    • 5分钟理解Focal Loss与GHM——解决样本不平衡利器
    • https://zhuanlan.zhihu.com/p/80594704

二、 中文NER任务实验小结报告——深入模型实现细节【代码讲解】

  • 中文NER任务实验小结报告——深入模型实现细节
    • https://zhuanlan.zhihu.com/p/103779616
    • https://github.com/qiufengyuyi/sequence_tagging
  • 概览
    • 1. 针对每一类tag找到一个文本中的start和end位置。
    • 2. 定义问法pattern,每一类tag一类问法
    • 3. 拿到bert的sequence_output,mask掉问法pattern和文本之外的token,拿到预测的start和end
    • 4. 计算loss
    • 注意:该作者只实现了基于start和end点的位置,没有实现start和end点匹配的位置。

1. 数据构建方式

1.1 训练数据构建
  • file:bert_mrc_prepare_data.py
  • func: trans_orig_data_to_training_data

> 海 钓 比 赛 地 点 在 厦 门 与 金 门 之 间 的 海 域 。 
>
> O O O O O O O B-LOC I-LOC O B-LOC I-LOC O O O O O O



1. 每一个类别以一个问法pattern
2. 一次mrc只能识别一种类别,所以构造数据的时候,label数据只是把该类别所有的start_index和end_index拿出来,所以可能有多个1.
3. 参数说明:
    data_X:包含pattern问法的输入,已经转化成token;
    data_start_Y: 起点位置,每个pattern对应一份起点位置,并且可能出现多个,这里起点位置好像没有计算pattern的长度。
    data_end_Y: 起点位置,每个pattern对应一份结束位置,并且可能出现多个,这里结束点位置好像没有计算pattern的长度。
    query_len: 问法pattern的长度
    token_type_ids: pattern的为0,原始text的为1
    总结:感觉负样本会比较多呀?
1.2 测试数据构建
  • file:bert_mrc_prepare_data.py
  • func: gen_test_data_from_orig_data
1. 每一个pattern问法都会分拆一些原始句子,然后去找对应的start和end起始位置。
2. 这里拆分句子的时候,不会根据实体位置来拆分了,因为测试数据没有实体的位置信息,所以这里有点不聪明,感觉是不是可以拿分词的结果。
3. 参数说明:
    data_X:包含pattern问法的输入,已经转化成token;
    src_test_sample_id: 就是一个行号/2,相当于每一段文本一个id
    query_class_list:对应的pattern的index
    token_type_ids_list: pattern为0, 原始text为1
    query_len_list: pattern的长度
    
1.3 构建输入
  • file:data_utils.py
  • func:data_generator_bert_mrc
- feature, label
- yield (input_x,len(input_x),query_len,token_type_id),(start_y,end_y)
    - input_x:输入,来自于data_X
    - len(input_x): input_x的长度
    - query_len: pattern问法的长度
    - token_type_id: pattern为0, 原始text为1
    - (start_y,end_y): 起始位置

2. 模型的输入输出

  • file:bert_mrc.py
  • func:bert_mrc_model_fn_builder
2.1 输入:
  • features, labels
input_ids,text_length_list,query_length_list,token_type_id_list = features
start_labels,end_labels = labels
2.2 模型构建:
  • file:bert_mrc.py
  • class:bertMRC
# 数列输出,每个token位置一个输出
bert_seq_output = bert_model.get_sequence_output() # self.sequence_output = self.all_encoder_layers[-1]

# bert_project = tf.layers.dense(bert_seq_output, self.hidden_units, activation=tf.nn.relu)
# bert_project = tf.layers.dropout(bert_project, rate=self.dropout_rate, training=is_training)
# 直接搞个全连接,找到start位置和end位置,每个位置对应一个01分类
start_logits = tf.layers.dense(bert_seq_output,self.num_labels)
end_logits = tf.layers.dense(bert_seq_output, self.num_labels)
# tf.sequence_mask:Returns a mask tensor representing the first N positions of each cell
# tf.sequence_mask:设置前n个为postive
# query_span_mask:相当于是mask 问法的pattern
query_span_mask = tf.cast(tf.sequence_mask(query_len_list),tf.int32) #
# total_seq_mask:相当于是mask有文本的字段
total_seq_mask = tf.cast(tf.sequence_mask(text_length_list),tf.int32)
query_span_mask = query_span_mask * -1
query_len_max = tf.shape(query_span_mask)[1] # 最长的问法的长度?
left_query_len_max = tf.shape(total_seq_mask)[1] - query_len_max # 剩下的文本的长度?
# query_span_mask:相当于是mask 问法的pattern
# left_query_len_max: 剩下的文本的长度?
# zero_mask_left_span这个是干啥
zero_mask_left_span = tf.zeros((tf.shape(query_span_mask)[0],left_query_len_max),dtype=tf.int32)
# query_span_mask:相当于是mask 问法的pattern
# zero_mask_left_span:concat起来做啥
final_mask = tf.concat((query_span_mask,zero_mask_left_span),axis=-1)
# query_span_mask->zero_mask_left_span->final_mask
# query_span_mask:相当于是mask 问法的pattern的mask,这里基础的*-1了
# total_seq_mask:为全部的mask所以相加后就为真实的mask了,只是中间这里是在干嘛?
final_mask = final_mask + total_seq_mask
predict_start_ids = tf.argmax(start_logits, axis=-1, name="pred_start_ids")
predict_end_ids = tf.argmax(end_logits, axis=-1, name="pred_end_ids")
if not is_testing:
    # one_hot_labels = tf.one_hot(labels, depth=self.num_labels, dtype=tf.float32)
    # start_loss = ce_loss(start_logits,start_labels,final_mask,self.num_labels,True)
    # end_loss = ce_loss(end_logits,end_labels,final_mask,self.num_labels,True)

    # focal loss
    start_loss = focal_loss(start_logits,start_labels,final_mask,self.num_labels,True)
    end_loss = focal_loss(end_logits,end_labels,final_mask,self.num_labels,True)
    # 就是这两个的叠加
    final_loss = start_loss + end_loss
    return final_loss,predict_start_ids,predict_end_ids,final_mask
else:
    return predict_start_ids,predict_end_ids,final_mask

  • file: utils.py
  • func: focal_loss

解析:

  1. 解决不平衡问题的时候,可以对正负样本的loss加一个权重,focal loss也是这么做的。
  2. 但是加权重还是不能解决所有问题,对于容易的问题比较多的例子,大量的容易样本的loss和可能也会占大多数,那么可以考虑给预测概率高的样本的权重降低一些,对就是这么做的。
  3. 公式:y=1时候,-a(1-p)^blog§, y=0的时候,-(1-a)p^blog(1-p), 一般b取2,a取0.25.

def focal_loss(logits,labels,mask,num_labels,one_hot=True,lambda_param=1.5):
    # 拿到预测的结果
    probs = tf.nn.softmax(logits,axis=-1)
    # 预测为1的结果
    pos_probs = probs[:,:,1]
    # pos的label和neg的label
    # 这里是啥操作,还有些没搞懂
    prob_label_pos = tf.where(tf.equal(labels,1),pos_probs,tf.ones_like(pos_probs))
    prob_label_neg = tf.where(tf.equal(labels,0),pos_probs,tf.zeros_like(pos_probs))
    # 计算focal loss
    # (1-prob_pos)^lambda_param * log(prob_pos + 1e-7) +
    # (prob_neg)^lambda_param * log(1- prob_neg + 1e-7) +
    # 感觉这里负样本是超级不平衡,用focal loss可能是会好一些
    loss = tf.pow(1. - prob_label_pos,lambda_param)*tf.log(prob_label_pos + 1e-7) + \
           tf.pow(prob_label_neg,lambda_param)*tf.log(1. - prob_label_neg + 1e-7)
    # 这里再衬衣一个amsk
    loss = -loss * tf.cast(mask,tf.float32)
    loss = tf.reduce_sum(loss,axis=-1,keepdims=True)
    # loss = loss/tf.cast(tf.reduce_sum(mask,axis=-1),tf.float32)
    loss = tf.reduce_mean(loss)
    return loss
2.3 模型预测
  • file:load_and_predict.py
  • class:fastPredictBertMrc
# 1. 拿到start_ids和end_ids
predictions = self.predict_fn({'words': [text], 'text_length': [text_length],'query_length':[query_len],'token_type_ids':[token_type_ids]})
start_ids,end_ids = predictions.get("start_ids"),predictions.get("end_ids")
#return start_ids[0],end_ids[0]

# 2. 根据开始,结尾标识,找到对应的实体
# 1.找到一个start=1的位置,往后找看有没有end=1的位置
# 2.如果找end=1的过程中,没找到end,反而又找到一个start=1,那么直接跳过,跳出循环,可能是单个字
# 3.找到一个匹配的end了,拿到实体,跳出循环,找下一个start和end对
# 4.如果到末尾都没找到匹配的end或者又找到一个新的start,这个时候会跳出循环,那么原来的start位置可以作为一个单字
# 5.??? 感觉会很粗糙啊,这种实现方式

def extract_entity_from_start_end_ids(self,orig_text,start_ids,end_ids):
    # 根据开始,结尾标识,找到对应的实体
    # start_ids是啥,是这样的吗[0,0,0,1,0,0,1,0,0]吗
    # end_ids是啥,是这样的吗  [0,0,0,0,0,1,0,0,1]吗
    entity_list = []
    for i,start_id in enumerate(start_ids):
        # 为啥跳过起点?
        if start_id == 0:
            continue
        # j从i+1开始
        j = i+1
        find_end_tag = False
        while j < len(end_ids):
            # 若在遇到end=1之前遇到了新的start=1,则停止该实体的搜索
            if start_ids[j] == 1:
                break
            # 匹配到一个
            if end_ids[j] == 1:
                # 把实体拿出来
                entity_list.append("".join(orig_text[i:j+1]))
                find_end_tag = True
                # 跳出来,找下一个start=1的位置
                break
            else:
                j+=1
        # 如果到末尾都没找到匹配的end或者又找到一个新的start,这个时候会跳出循环,那么原来的start位置可以作为一个单字
        if not find_end_tag:
            # 实体就一个单字
            entity_list.append("".join(orig_text[i:i+1]))
    return entity_list      
  • file:load_and_predict.py
  • func:predict_entitys_for_all_sample
# 每个长句子是拆解成多个短句子,并且每个句子又分为多个tag的类别,所以要一直合并。
# 1. 先合并同一个句子同一个类别的
# 2. 对于不同类别的,要先处理上一个类别的实体
# 3. sample_id都不同,代表句子都不同了,先把上一个处理完,并更新buffer里面的内容
# 5. 注意处理最后一个样本


    def predict_entitys_for_all_sample(self,text_data_Xs,query_lens,token_type_ids_list,query_class_list,src_sample_ids_list,orig_text_list):
        result_list = [] # 存储的是每个样本每个实体类别对应的实体列表,有可能是空的
        cur_sample_id_buffer = 0
        start_ids_buffer = []
        end_ids_buffer = []
        query_class_buffer = ""
        for i in range(len(text_data_Xs)):
            cur_text = text_data_Xs[i]
            cur_query_len = query_lens[i]
            cur_token_type_ids = token_type_ids_list[i]
            cur_query_class = query_class_list[i]
            cur_src_sample_id = src_sample_ids_list[i]
            start_ids,end_ids = self.predict_mrc(cur_text,cur_query_len,cur_token_type_ids)
            # 去掉query
            # print(type(start_ids))
            # 拿到真正的start,end,和label
            true_start_ids = start_ids[cur_query_len:].tolist()
            true_end_ids = end_ids[cur_query_len:].tolist()
            cur_query_class_str = ner_query_map.get("tags")[cur_query_class]
            # 首个样本?知道了,每个长句子是拆解成多个短句子,所以
            if query_class_buffer == "" or len(start_ids_buffer)==0:
                # 首个样本,都添加到buffer中
                query_class_buffer = cur_query_class_str
                start_ids_buffer.extend(true_start_ids)
                end_ids_buffer.extend(true_end_ids)
                cur_sample_id_buffer = cur_src_sample_id # 每个句子应该是一个id
            elif cur_src_sample_id == cur_sample_id_buffer and cur_query_class_str == query_class_buffer:
                # 同一个样本,同一个query,要合并
                start_ids_buffer.extend(true_start_ids)
                end_ids_buffer.extend(true_end_ids)
            elif cur_src_sample_id == cur_sample_id_buffer:
                # 遇到不同query 类型,先处理上一个query类型的样本实体识别
                cur_orig_text = orig_text_list[cur_sample_id_buffer]

                extracted_entity_list = self.extract_entity_from_start_end_ids(cur_orig_text,start_ids_buffer,end_ids_buffer)
                # print(result_list)
                # print(cur_src_sample_id)
                # print(cur_orig_text)
                if len(result_list) == 0:
                    # 初始情况
                    # buffer 的query class 更新
                    result_list.append({query_class_buffer:extracted_entity_list})
                else:
                    if cur_sample_id_buffer >= len(result_list):
                        result_list.append({query_class_buffer: extracted_entity_list})
                    else:
                        result_list[cur_sample_id_buffer].update({query_class_buffer:extracted_entity_list})
                # 更新query_class_buffer
                query_class_buffer = cur_query_class_str
                # 更新start_ids_buffer,end_ids_buffer
                start_ids_buffer = true_start_ids
                end_ids_buffer = true_end_ids
            else:
                # 本轮为新的样本
                cur_orig_text = orig_text_list[cur_sample_id_buffer]
                extracted_entity_list = self.extract_entity_from_start_end_ids(cur_orig_text, start_ids_buffer,
                                                                               end_ids_buffer)
                # if cur_src_sample_id == 2:
                #     print(extracted_entity_list)
                # 更新上一个id的样本实体抽取
                # print(cur_sample_id_buffer)
                # print(result_list)
                if cur_sample_id_buffer >= len(result_list):
                    result_list.append({query_class_buffer: extracted_entity_list})
                else:
                    result_list[cur_sample_id_buffer].update({query_class_buffer: extracted_entity_list})
                query_class_buffer = cur_query_class_str
                start_ids_buffer = true_start_ids
                end_ids_buffer = true_end_ids
                cur_sample_id_buffer = cur_src_sample_id
        # deal with last sample
        cur_orig_text = orig_text_list[cur_sample_id_buffer]
        extracted_entity_list = self.extract_entity_from_start_end_ids(cur_orig_text, start_ids_buffer,
                                                                       end_ids_buffer)
        if cur_sample_id_buffer >= len(result_list):
            result_list.append({query_class_buffer: extracted_entity_list})
        else:
            result_list[cur_sample_id_buffer].update({query_class_buffer: extracted_entity_list})
        return result_list
2.5 其他函数
2.5.1 根据BIO数据生成实体
  • file:load_and_predict.py
  • func:gen_entity_from_label_id_list
def gen_entity_from_label_id_list(text_lists,label_id_list,id2slot_dict,orig_test=False):
    """
    B-LOC
    B-PER
    B-ORG
    I-LOC
    I-ORG
    I-PER
    :param label_id_list:
    :param id2slot_dict:
    :return:

    text_list : [["北","京","的","天","安","门"]]
    label_list : [["B-","I-","O","B-","I-","I-"]]
    outputs: ["北京", "天安门"]
    """
    entity_list = []
    # 存index
    buffer_list = []
    for i,label_ids in enumerate(label_id_list):
        # 拿到当前句子和当前label
        cur_entity_list = [] # ["北京", "天安门"]
        if not orig_test:
            label_list = [id2slot_dict.get(label_ele) for label_ele in label_ids]
        else:
            label_list = label_ids
        text_list = text_lists[i]
        # label_list
        # print(label_list)
        # 遍历当前label
        for j,label in enumerate(label_list):
            # 遇到O了,如果buffer里面有数据,添加实体
            if not label.__contains__("-"):
                if len(buffer_list)==0:
                    continue
                else:
                    # print(buffer_list)
                    # print(text_list)
                    buffer_char_list = [text_list[index] for index in buffer_list]
                    buffer_word = "".join(buffer_char_list)
                    cur_entity_list.append(buffer_word)
                    buffer_list.clear()
            else:
                # 如果buffer里面没有数据
                if len(buffer_list) == 0:
                    # 如果遇到B开始了,更新buffer
                    if label.startswith("B"):
                        #必须以B开头,否则说明有问题,不能加入
                        buffer_list.append(j)
                # 如果有数据
                else:
                    # 检查最后一个index的label
                    buffer_last_index = buffer_list[-1]
                    buffer_last_label = label_list[buffer_last_index]
                    split_label = buffer_last_label.split("-")
                    # B-ORG,一个是起点,中间位置标记,一个是tag的类别
                    buffer_last_label_prefix,buffer_last_label_type = split_label[0],split_label[1]
                    # 拿到当前的位置的B和tag类别
                    cur_label_split = label.split("-")
                    cur_label_prefix,cur_label_type = cur_label_split[0],cur_label_split[1]
                    # B+B
                    # 两个都为B,把
                    if buffer_last_label_prefix=="B" and cur_label_prefix=="B":
                        # 相当于是搞了一个单字的
                        cur_entity_list.append(text_list[buffer_list[-1]])
                        buffer_list.clear()
                        buffer_list.append(j)
                    # 遇到一个新的实体,加上老的实体并更新buffer
                    elif buffer_last_label_prefix=="I" and cur_label_prefix=="B":
                        buffer_char_list = [text_list[index] for index in buffer_list]
                        buffer_word = "".join(buffer_char_list)
                        cur_entity_list.append(buffer_word)
                        buffer_list.clear()
                        buffer_list.append(j)
                    # B和I对的上,看下tag的类别对不对得上
                    elif buffer_last_label_prefix=="B" and cur_label_prefix=="I":
                        # analyze type
                        # 类型相同直接加上去
                        if buffer_last_label_type == cur_label_type:
                            buffer_list.append(j)
                        else:
                            # 类型不同,加上上一个,当前的有问题就不加了
                            cur_entity_list.append(text_list[buffer_list[-1]])
                            buffer_list.clear()
                            # 这种情况出现在预测有问题,即一个I的label不应当作为一个实体的起始。
                            #buffer_list.append(j)
                    else:
                        # I + I
                        # analyze type
                        # 类型相同可以加
                        if buffer_last_label_type == cur_label_type:
                            buffer_list.append(j)
                        else:
                            # 不同的话,加上上一个,感觉后面那个i也可以不加了啊
                            cur_entity_list.append(text_list[buffer_list[-1]])
                            buffer_list.clear()
                            buffer_list.append(j)
        # 最后一个
        if buffer_list:
            buffer_char_list = [text_list[index] for index in buffer_list]
            buffer_word = "".join(buffer_char_list)
            cur_entity_list.append(buffer_word)
            buffer_list.clear()
        # 加上实体
        entity_list.append(cur_entity_list)
    return entity_list

2.5.2 计算实体识别的metric
def cal_mertric_from_two_list(prediction_list,true_list):
    tp, fp, fn = 0, 0, 0
    for pred_entity, true_entity in zip(prediction_list, true_list):
        pred_entity_set = set(pred_entity)
        true_entity_set = set(true_entity)
        tp += len(true_entity_set & pred_entity_set)
        fp += len(pred_entity_set - true_entity_set)
        fn += len(true_entity_set - pred_entity_set)
    # 召回的精度
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0
    # 召回
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0
    # f2=2*precision*recall/(precision+recall)
    f1 = 2 * prec * rec / (prec + rec)
    print("span_level pre micro_avg:{}".format(prec))
    print("span_level rec micro_avg:{}".format(rec))
    print("span_level f1 micro_avg:{}".format(f1))

三、论文原始代码,写得挺好的

一、Train

  • 基础介绍
    The main training procedure is in trainer.py

Examples to start training are in scripts/reproduce.

Note that you may need to change DATA_DIR, BERT_DIR, OUTPUT_DIR to your own
dataset path, bert model path and log path, respectively.

1. 数据转化
  1. 拿到每个实体的类型以及起始位置信息。
  2. 对于每一段文本,遍历tag的类别,对每一个tag类别抽取起始位置label和问法的pattern,以及原始的context。

核心代码:

for label, query in tag2query.items():
    mrc_samples.append(
        {
            "context": src,
            "start_position": [tag.begin for tag in tags if tag.tag == label],
            "end_position": [tag.end-1 for tag in tags if tag.tag == label],
            "query": query
        }
    )
2. 训练数据生成
  1. 拼接query和context,生成token
  2. 修正起始位置(前面加了query,并且英文分词用了BertWordPieceTokenizer)
  3. 生成match label,相当于每一对start和end之间的match lable都为1,其他都为0
  4. padding
class MRCNERDataset(Dataset):
    """
    MRC NER Dataset
    Args:
        json_path: path to mrc-ner style json
        tokenizer: BertTokenizer
        max_length: int, max length of query+context
        possible_only: if True, only use possible samples that contain answer for the query/context
        is_chinese: is chinese dataset
    """
    def __init__(self, json_path, tokenizer: BertWordPieceTokenizer, max_length: int = 128, possible_only=False,
                 is_chinese=False, pad_to_maxlen=False):
        self.all_data = json.load(open(json_path, encoding="utf-8"))
        self.tokenzier = tokenizer
        self.max_length = max_length
        self.possible_only = possible_only
        if self.possible_only:
            self.all_data = [
                x for x in self.all_data if x["start_position"]
            ]
        self.is_chinese = is_chinese
        self.pad_to_maxlen = pad_to_maxlen

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, item):
        """
        Args:
            item: int, idx
        Returns:
            tokens: tokens of query + context, [seq_len]
            token_type_ids: token type ids, 0 for query, 1 for context, [seq_len]
            start_labels: start labels of NER in tokens, [seq_len]
            end_labels: end labelsof NER in tokens, [seq_len]
            label_mask: label mask, 1 for counting into loss, 0 for ignoring. [seq_len]
            match_labels: match labels, [seq_len, seq_len]
            sample_idx: sample id
            label_idx: label id

        """
        data = self.all_data[item]
        tokenizer = self.tokenzier

        # 这个有啥用
        qas_id = data.get("qas_id", "0.0")
        sample_idx, label_idx = qas_id.split(".")
        sample_idx = torch.LongTensor([int(sample_idx)])
        label_idx = torch.LongTensor([int(label_idx)])

        # 原始数据
        query = data["query"]
        context = data["context"]
        start_positions = data["start_position"]
        end_positions = data["end_position"]

        if self.is_chinese:
            # 这是个啥,把空格去掉了吗?
            context = "".join(context.split())
            # 修正一下end位置
            end_positions = [x+1 for x in end_positions]
        else:
            # add space offsets
            # 英文的话,在计算起始位置的时候,要加上空格的数量
            words = context.split()
            start_positions = [x + sum([len(w) for w in words[:x]]) for x in start_positions]
            end_positions = [x + sum([len(w) for w in words[:x + 1]]) for x in end_positions]

        # 将query和context放进去,转化成为token
        query_context_tokens = tokenizer.encode(query, context, add_special_tokens=True)
        tokens = query_context_tokens.ids
        type_ids = query_context_tokens.type_ids # query为0,context为1
        offsets = query_context_tokens.offsets # 这个是个啥,mask?

        # find new start_positions/end_positions, considering
        # 1. we add query tokens at the beginning
        # 2. word-piece tokenize
        # 要重新计算一下start label和end label, 原因如下:
        # 1. 因为加了query的tokens放在前面
        # 2. 使用了word-piece分词
        origin_offset2token_idx_start = {}
        origin_offset2token_idx_end = {}
        for token_idx in range(len(tokens)):
            # skip query tokens
            if type_ids[token_idx] == 0:
                continue
            # 没搞懂这是个啥
            token_start, token_end = offsets[token_idx]
            # skip [CLS] or [SEP]
            if token_start == token_end == 0:
                continue
            # 拿到每个位置的offset
            origin_offset2token_idx_start[token_start] = token_idx
            origin_offset2token_idx_end[token_end] = token_idx
        # 拿到新的start position和end position
        # 估计中文的不受影响
        # 另外感觉这个tokenizer是自己设计了的,返回了offset
        new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions]
        new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions]

        label_mask = [
            # 前面的代表query需要mask
            # offsets[token_idx] == (0, 0) 这个没搞懂
            (0 if type_ids[token_idx] == 0 or offsets[token_idx] == (0, 0) else 1)
            for token_idx in range(len(tokens))
        ]
        # 两个mask其实是一样的
        start_label_mask = label_mask.copy()
        end_label_mask = label_mask.copy()

        # the start/end position must be whole word
        # 对于非中文,还需要检查起始位置是否为字
        if not self.is_chinese:
            for token_idx in range(len(tokens)):
                current_word_idx = query_context_tokens.words[token_idx]
                next_word_idx = query_context_tokens.words[token_idx+1] if token_idx+1 < len(tokens) else None
                prev_word_idx = query_context_tokens.words[token_idx-1] if token_idx-1 > 0 else None
                if prev_word_idx is not None and current_word_idx == prev_word_idx:
                    start_label_mask[token_idx] = 0
                if next_word_idx is not None and current_word_idx == next_word_idx:
                    end_label_mask[token_idx] = 0

        assert all(start_label_mask[p] != 0 for p in new_start_positions)
        assert all(end_label_mask[p] != 0 for p in new_end_positions)

        assert len(new_start_positions) == len(new_end_positions) == len(start_positions)
        assert len(label_mask) == len(tokens)
        # new_start_positions 这里只存了start的集合
        # new_end_positions 也只存了end的集合
        start_labels = [(1 if idx in new_start_positions else 0)
                        for idx in range(len(tokens))]
        end_labels = [(1 if idx in new_end_positions else 0)
                      for idx in range(len(tokens))]

        # truncate
        # 截断
        tokens = tokens[: self.max_length]
        type_ids = type_ids[: self.max_length]
        start_labels = start_labels[: self.max_length]
        end_labels = end_labels[: self.max_length]
        start_label_mask = start_label_mask[: self.max_length]
        end_label_mask = end_label_mask[: self.max_length]

        # make sure last token is [SEP]
        # 末尾为啥是个sep呀
        sep_token = tokenizer.token_to_id("[SEP]")
        if tokens[-1] != sep_token:
            assert len(tokens) == self.max_length
            tokens = tokens[: -1] + [sep_token]
            start_labels[-1] = 0
            end_labels[-1] = 0
            start_label_mask[-1] = 0
            end_label_mask[-1] = 0

        if self.pad_to_maxlen:
            tokens = self.pad(tokens, 0)
            type_ids = self.pad(type_ids, 1)
            start_labels = self.pad(start_labels)
            end_labels = self.pad(end_labels)
            start_label_mask = self.pad(start_label_mask)
            end_label_mask = self.pad(end_label_mask)

        seq_len = len(tokens)
        match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
        # match的label,相当于是一对起始位置,他们的label才为1,其余都为0
        for start, end in zip(new_start_positions, new_end_positions):
            if start >= seq_len or end >= seq_len:
                continue
            match_labels[start, end] = 1

        return [
            torch.LongTensor(tokens),
            torch.LongTensor(type_ids),
            torch.LongTensor(start_labels),
            torch.LongTensor(end_labels),
            torch.LongTensor(start_label_mask),
            torch.LongTensor(end_label_mask),
            match_labels,
            sample_idx,
            label_idx
        ]


3. 模型
  1. 全连接拿到起始位置的logit
  2. 对于每一个起始位置,再判定是否是一个实体的起始位置,通过一个MultiNonLinearClassifier拿到match的logit

import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel

from models.classifier import MultiNonLinearClassifier, SingleLinearClassifier


class BertQueryNER(BertPreTrainedModel):
    def __init__(self, config):
        super(BertQueryNER, self).__init__(config)
        self.bert = BertModel(config)

        # self.start_outputs = nn.Linear(config.hidden_size, 2)
        # self.end_outputs = nn.Linear(config.hidden_size, 2)
        # 线性分类,就是全连接吗
        self.start_outputs = nn.Linear(config.hidden_size, 1)
        self.end_outputs = nn.Linear(config.hidden_size, 1)
        # MultiNonLinearClassifier这个是个啥
        self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, 1, config.mrc_dropout)
        # self.span_embedding = SingleLinearClassifier(config.hidden_size * 2, 1)

        self.hidden_size = config.hidden_size

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """
        Args:
            input_ids: bert input tokens, tensor of shape [seq_len]
            token_type_ids: 0 for query, 1 for context, tensor of shape [seq_len]
            attention_mask: attention mask, tensor of shape [seq_len]
        Returns:
            start_logits: start/non-start probs of shape [seq_len]
            end_logits: end/non-end probs of shape [seq_len]
            match_logits: start-end-match probs of shape [seq_len, 1]
        """

        # bert的输出
        bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

        sequence_heatmap = bert_outputs[0]  # [batch, seq_len, hidden] # 应该是每个位置一个hidden大小的
        batch_size, seq_len, hid_size = sequence_heatmap.size()

        # 通过全连接,拿到每个位置是否为起点,以及是否为终点的logits
        start_logits = self.start_outputs(sequence_heatmap).squeeze(-1)  # [batch, seq_len, 1]
        end_logits = self.end_outputs(sequence_heatmap).squeeze(-1)  # [batch, seq_len, 1]

        # for every position $i$ in sequence, should concate $j$ to
        # predict if $i$ and $j$ are start_pos and end_pos for an entity.
        # [batch, seq_len, seq_len, hidden]
        # 对于每一个位置(i,j)位置都有可能是一个候选的起始位置对
        # 所以拿这个去预测这个i,j是否是一个实体的起始位置
        start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
        # [batch, seq_len, seq_len, hidden]
        end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # [batch, seq_len, seq_len, hidden*2]
        span_matrix = torch.cat([start_extend, end_extend], 3)
        # [batch, seq_len, seq_len]
        span_logits = self.span_embedding(span_matrix).squeeze(-1)

        return start_logits, end_logits, span_logits


class MultiNonLinearClassifier(nn.Module):
    def __init__(self, hidden_size, num_label, dropout_rate):
        super(MultiNonLinearClassifier, self).__init__()
        self.num_label = num_label
        self.classifier1 = nn.Linear(hidden_size, hidden_size)
        self.classifier2 = nn.Linear(hidden_size, num_label)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_features):
        features_output1 = self.classifier1(input_features)
        # features_output1 = F.relu(features_output1)
        features_output1 = F.gelu(features_output1)
        features_output1 = self.dropout(features_output1)
        features_output2 = self.classifier2(features_output1)
        return features_output2
4. 训练
  1. 拿到start和end的logit并计算loss
  2. 计算start和end的mask,注意match也logit也可以计算mask,可以根据起点必须小于终点来mask
  3. 计算matchloss的时候,有多种选择,一种是计算label为1的位置的loss,一种是计算pred为1或者label为1位置的loss,是不是还有其他的?
  4. 两种计算loss的方式,一种是交叉熵,一种是diceloss,diceloss是2*交集/(a+b)
    def compute_loss(self, start_logits, end_logits, span_logits,
                     start_labels, end_labels, match_labels, start_label_mask, end_label_mask):
        batch_size, seq_len = start_logits.size()

        start_float_label_mask = start_label_mask.view(-1).float()
        end_float_label_mask = end_label_mask.view(-1).float()
        match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(-1, -1, seq_len)
        match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(-1, seq_len, -1)
        match_label_mask = match_label_row_mask & match_label_col_mask
        # match也有一个mask,起点位置肯定要小于结束的位置
        match_label_mask = torch.triu(match_label_mask, 0)  # start should be less equal to end

        # 对于所有的candidates都计算loss
        if self.span_loss_candidates == "all":
            # naive mask
            float_match_label_mask = match_label_mask.view(batch_size, -1).float()
        else:
            # 只对golden的起始位置计算loss
            # use only pred or golden start/end to compute match loss
            start_preds = start_logits > 0
            end_preds = end_logits > 0
            if self.span_loss_candidates == "gold":
                # 只对label的位置计算loss
                match_candidates = ((start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
                                    & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
            else:
                # 对预测为1或者label为1的位置才计算loss
                match_candidates = torch.logical_or(
                    (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_labels.unsqueeze(-2).expand(-1, seq_len, -1))
                )
                # 这里感觉还有其他方法呀,负样本有点多,可以用下focal loss来计算?
            match_label_mask = match_label_mask & match_candidates
            float_match_label_mask = match_label_mask.view(batch_size, -1).float()
        if self.loss_type == "bce":
            # 交叉熵
            start_loss = self.bce_loss(start_logits.view(-1), start_labels.view(-1).float())
            start_loss = (start_loss * start_float_label_mask).sum() / start_float_label_mask.sum()
            # 交叉熵
            end_loss = self.bce_loss(end_logits.view(-1), end_labels.view(-1).float())
            end_loss = (end_loss * end_float_label_mask).sum() / end_float_label_mask.sum()
            # 还是交叉熵
            match_loss = self.bce_loss(span_logits.view(batch_size, -1), match_labels.view(batch_size, -1).float())
            match_loss = match_loss * float_match_label_mask
            match_loss = match_loss.sum() / (float_match_label_mask.sum() + 1e-10)
        else:
            # 下面是diceloss
            start_loss = self.dice_loss(start_logits, start_labels.float(), start_float_label_mask)
            end_loss = self.dice_loss(end_logits, end_labels.float(), end_float_label_mask)
            match_loss = self.dice_loss(span_logits, match_labels.float(), float_match_label_mask)

        return start_loss, end_loss, match_loss





# encoding: utf-8


import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional


class DiceLoss(nn.Module):
    """
    Dice coefficient for short, is an F1-oriented statistic used to gauge the similarity of two sets.
    Given two sets A and B, the vanilla dice coefficient between them is given as follows:
        Dice(A, B)  = 2 * True_Positive / (2 * True_Positive + False_Positive + False_Negative)
                    = 2 * |A and B| / (|A| + |B|)

    Math Function:
        U-NET: https://arxiv.org/abs/1505.04597.pdf
        dice_loss(p, y) = 1 - numerator / denominator
            numerator = 2 * \sum_{1}^{t} p_i * y_i + smooth
            denominator = \sum_{1}^{t} p_i + \sum_{1} ^{t} y_i + smooth
        if square_denominator is True, the denominator is \sum_{1}^{t} (p_i ** 2) + \sum_{1} ^{t} (y_i ** 2) + smooth
        V-NET: https://arxiv.org/abs/1606.04797.pdf
    Args:
        smooth (float, optional): a manual smooth value for numerator and denominator.
        square_denominator (bool, optional): [True, False], specifies whether to square the denominator in the loss function.
        with_logits (bool, optional): [True, False], specifies whether the input tensor is normalized by Sigmoid/Softmax funcs.
            True: the loss combines a `sigmoid` layer and the `BCELoss` in one single class.
            False: the loss contains `BCELoss`.
    Shape:
        - input: (*)
        - target: (*)
        - mask: (*) 0,1 mask for the input sequence.
        - Output: Scalar loss
    Examples:
        >>> loss = DiceLoss()
        >>> input = torch.randn(3, 1, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()
    """
    def __init__(self,
                 smooth: Optional[float] = 1e-8,
                 square_denominator: Optional[bool] = False,
                 with_logits: Optional[bool] = True,
                 reduction: Optional[str] = "mean") -> None:
        super(DiceLoss, self).__init__()

        self.reduction = reduction
        self.with_logits = with_logits
        self.smooth = smooth
        self.square_denominator = square_denominator

    def forward(self,
                input: Tensor,
                target: Tensor,
                mask: Optional[Tensor] = None) -> Tensor:

        flat_input = input.view(-1)
        flat_target = target.view(-1)

        if self.with_logits:
            flat_input = torch.sigmoid(flat_input)

        if mask is not None:
            mask = mask.view(-1).float()
            flat_input = flat_input * mask
            flat_target = flat_target * mask
        # 2*交集/(A+B)
        interection = torch.sum(flat_input * flat_target, -1)
        if not self.square_denominator:
            return 1 - ((2 * interection + self.smooth) /
                        (flat_input.sum() + flat_target.sum() + self.smooth))
        else:
            return 1 - ((2 * interection + self.smooth) /
                        (torch.sum(torch.square(flat_input,), -1) + torch.sum(torch.square(flat_target), -1) + self.smooth))

    def __str__(self):
        return f"Dice Loss smooth:{self.smooth}"

二、Test

这个有些没看懂啊,测试这么简单的?
可以看下这个:
pytorch_lightning 全程笔记:https://zhuanlan.zhihu.com/p/319810661


# encoding: utf-8


import os
from pytorch_lightning import Trainer

from trainer import BertLabeling


def evaluate(ckpt, hparams_file):
    """main"""

    trainer = Trainer(gpus=[0, 1], distributed_backend="ddp")

    model = BertLabeling.load_from_checkpoint(
        checkpoint_path=ckpt,
        hparams_file=hparams_file,
        map_location=None,
        batch_size=1,
        max_length=128,
        workers=0
    )
    trainer.test(model=model)


if __name__ == '__main__':
    # ace04
    HPARAMS = "/mnt/mrc/train_logs/ace2004/ace2004_20200911reproduce_epoch15_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/lightning_logs/version_0/hparams.yaml"
    CHECKPOINTS = "/mnt/mrc/train_logs/ace2004/ace2004_20200911reproduce_epoch15_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/epoch=10_v0.ckpt"
    # DIR = "/mnt/mrc/train_logs/ace2004/ace2004_20200910_lr3e-5_drop0.3_bert0.1_bsz32_hard_loss_bce_weight_span0.05"
    # CHECKPOINTS = [os.path.join(DIR, x) for x in os.listdir(DIR)]

    # ace04-large
    HPARAMS = "/mnt/mrc/train_logs/ace2004/ace2004_20200910reproduce_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/lightning_logs/version_2/hparams.yaml"
    CHECKPOINTS = "/mnt/mrc/train_logs/ace2004/ace2004_20200910reproduce_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/epoch=10.ckpt"

    # ace05
    # HPARAMS = "/mnt/mrc/train_logs/ace2005/ace2005_20200911_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/lightning_logs/version_0/hparams.yaml"
    # CHECKPOINTS = "/mnt/mrc/train_logs/ace2005/ace2005_20200911_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/epoch=15.ckpt"

    # zh_msra
    CHECKPOINTS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/epoch=2_v1.ckpt"
    HPARAMS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/lightning_logs/version_2/hparams.yaml"


    evaluate(ckpt=CHECKPOINTS, hparams_file=HPARAMS)

三、其他函数

1. BMES解码

功能:将BMES类label转化为一个个实体标签

  • file :bmes_decode.py

def bmes_decode(char_label_list: List[Tuple[str, str]]) -> List[Tag]:
    """
    decode inputs to tags
    Args:
        char_label_list: list of tuple (word, bmes-tag)
    Returns:
        tags
    Examples:
        >>> x = [("Hi", "O"), ("Beijing", "S-LOC")]
        >>> bmes_decode(x)
        [{'term': 'Beijing', 'tag': 'LOC', 'begin': 1, 'end': 2}]
    """
    idx = 0
    length = len(char_label_list)
    tags = []
    while idx < length:
        term, label = char_label_list[idx]
        current_label = label[0]

        # correct labels
        # BMES->起点,中点,结束点,单个位置
        # 当前如果为M或者E,都代表还是这个实体,将其置为B
        if current_label in ["M", "E"]:
            current_label = "B"
        # 到达终点点,并且当前lable还是B,感觉可以回收一下这个实体
        if idx + 1 == length and current_label == "B":
            current_label = "S"

        # merge chars
        # 如果为O,跳过
        if current_label == "O":
            idx += 1
            continue
        # 如果为S,回收实体
        if current_label == "S":
            tags.append(Tag(term, label[2:], idx, idx + 1))
            idx += 1
            continue
        # 如果为B,这里也合并了M和E
        if current_label == "B":
            # 往后找
            end = idx + 1
            # 如果为M,可以继续往后找
            while end + 1 < length and char_label_list[end][1][0] == "M":
                end += 1
            # 如果为E
            if char_label_list[end][1][0] == "E":  # end with E
                # 回收实体
                entity = "".join(char_label_list[i][0] for i in range(idx, end + 1))
                tags.append(Tag(entity, label[2:], idx, end + 1))
                idx = end + 1
            else:  # end with M/B
                # 如果BM都找完了,但是没有E,也回收一下实体
                entity = "".join(char_label_list[i][0] for i in range(idx, end))
                tags.append(Tag(entity, label[2:], idx, end))
                idx = end
            continue
        else:
            raise Exception("Invalid Inputs")
    return tags

四、自己重构与实现

待完成

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值