Entity Linking for Chinese Short Texts

The Entity Linking Task for Chinese Short Texts

1. Task Description

This evaluation task centers on entity linking technology and the AI application scenarios it serves. It extends and improves the CCKS 2019 entity linking task for Chinese short texts; the main changes are:
(1) entity recognition is removed, so the task focuses on disambiguating highly ambiguous entities in Chinese short texts;
(2) for new entities that are not in the knowledge base (NIL entities), the upper-level concept type must also be predicted;
(3) the annotated text data is adjusted: text sources from multimodal scenarios are added, and the proportion of highly ambiguous entities is rebalanced.
Entity Linking (EL) for Chinese short texts: given a Chinese short text (e.g. a search query, a Weibo post, dialogue content, or the title of an article/video/image), EL links the entity mentions in the text to the corresponding entities in a given knowledge base.
Traditional entity linking mainly targets long documents, whose rich context helps disambiguate entities and complete the links. Entity linking for Chinese short texts is much more challenging, mainly because:
(1) short texts are highly colloquial, which makes entity disambiguation difficult;
(2) short texts provide little context, so the limited context must be understood very precisely;
(3) compared with English, the characteristics of the Chinese language make short-text linking even harder.
The input and output of the task are defined as follows:
Input:
a Chinese short text and the set of entity mentions in that text.
Output:
the entity linking results for the short text. Each result contains the entity mention, its character offset in the short text, and its id in the given knowledge base; for a NIL mention, the entity's upper-level concept type must be given instead (the closed set of concept types is detailed in the appendix).
Example input:
[example image omitted]
Example output:
[example image omitted]

Note: for a query with ambiguous entities, the system must be able to decide which candidate entity in the knowledge base is the correct link. For example, the knowledge base may contain 8 different entities that could all be the correct link for 『琅琊榜』, because all 8 can be reached from the surface form 『琅琊榜』; but the given context (『海燕』, 『原创小说』, 『权谋小说』) carries enough information to decide which of these candidates should actually be linked.
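A rough sketch of the format (the original example images are not reproduced here; the field names follow the data files loaded by the code later in this post, and the text, offset and kb_id values below are invented for illustration):

Input record:
{"text_id": "1", "text": "海燕原创小说《琅琊榜》", "mention_data": [{"mention": "琅琊榜", "offset": "7"}]}
Output record (each mention resolved to a KB id, or to NIL_<Type> for entities outside the KB):
{"text_id": "1", "text": "海燕原创小说《琅琊榜》", "mention_data": [{"mention": "琅琊榜", "offset": "7", "kb_id": "123456"}]}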

2. Data Description
2.1 Knowledge Base

The knowledge base for this task comes from Baidu Baike. Each entity in the knowledge base has a subject_id (KB id), a subject name, the entity's aliases, its concept type(s), and a list of <predicate, object> (attribute, attribute value) pairs describing the entity. Each line of the knowledge base file is one record (one entity), and each record is in JSON format.
An example is shown below:
[example image omitted]
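The example image is not reproduced here. Based on the fields listed above and the sample values quoted in the preprocessing comments later in this post (the 张健 entry), a KB record presumably looks roughly like this (illustrative only; the alias list is left empty and the attribute list is truncated):

{"subject_id": "10001", "subject": "张健", "alias": [], "type": "Person", "data": [{"predicate": "义项描述", "object": "潜山县塔畈乡副主任科员、纪委副书记"}, {"predicate": "性别", "object": "男"}, {"predicate": "学历", "object": "大专"}, {"predicate": "中文名", "object": "张健"}]}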

2.2 Annotated Dataset

The annotated dataset consists of a training set, a validation set and a test set, about 100,000 records in total. All data was annotated via Baidu's crowdsourcing platform; details of the annotation quality will be released together with the data.
Each record in the annotated dataset has the following format:
[format image omitted]
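The format image is not reproduced here. Based on the fields read by the preprocessing code below (text_id, text, and mention_data entries holding mention, offset and kb_id, where kb_id is either a numeric KB id or NIL_<Type> for entities outside the KB), an annotated record presumably looks something like this; the kb_id values are invented for illustration:

{"text_id": "1", "text": "小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人", "mention_data": [{"mention": "小品", "offset": "0", "kb_id": "NIL_Other"}, {"mention": "吴京", "offset": "10", "kb_id": "123456"}]}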

The annotated data mainly comes from real web page titles, video titles, and search queries.
Example annotated texts:
[example image omitted]

3. Evaluation Metric

[formula image omitted; the task is scored with mention-level Micro-F1, computed by the official evaluation script reproduced at the end of this post]
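As a minimal sketch of how the metric is computed (mirroring the Eval.micro_f1 method at the end of this post; the counts below are made up):

# tp: predicted mentions whose (text_id, text, mention, offset) key and kb_id match a gold entry
# fp: predicted mentions that do not match; total_recall: number of gold mentions
tp, fp, total_recall = 80, 20, 100
precision = tp / (tp + fp)                            # 0.80
recall = tp / total_recall                            # 0.80
f1 = 2 * precision * recall / (precision + recall)    # 0.80

The rest of this post walks through a baseline implementation: preprocessing the knowledge base and the annotated data, a BERT-based entity linking model, an entity typing model for NIL mentions, and finally the official evaluation script.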

import torch
if torch.cuda.is_available():
    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
There are 1 GPU(s) available.
We will use the GPU: GeForce GTX 1070
Set up the paths
import os
import json
import logging
import random
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from transformers import (
    DataProcessor,
    InputExample,
    BertConfig,
    BertTokenizer,
    BertForSequenceClassification,
    glue_convert_examples_to_features,
)

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Path to the pretrained model weights
PRETRAINED_PATH = './chinese_roberta_wwm_ext_pytorch/'
# Checkpoint directory for entity linking training
EL_SAVE_PATH = './pytorch-lightning-checkpoints/EntityLinking/'
# Checkpoint directory for entity typing training
ET_SAVE_PATH = './pytorch-lightning-checkpoints/EntityTyping/'

# Project data directory
DATA_PATH = './data/'

# Raw data of the CCKS 2020 entity linking competition
RAW_PATH = DATA_PATH + 'ccks2020_el_data_v1/'

# Pickle files exported by preprocessing
PICKLE_PATH = DATA_PATH + 'pickle/'
if not os.path.exists(PICKLE_PATH):
    os.mkdir(PICKLE_PATH)

# Prediction result files
RESULT_PATH = DATA_PATH + 'result/'
if not os.path.exists(RESULT_PATH):
    os.mkdir(RESULT_PATH)

# tsv files used for training, validation and inference
TSV_PATH = DATA_PATH + 'tsv/'
if not os.path.exists(TSV_PATH):
    os.mkdir(TSV_PATH)

# Checkpoint files produced by training
CKPT_PATH = './ckpt/'

PICKLE_DATA = {
    # Entity name -> set of KB ids
    'ENTITY_TO_KBIDS': None,
    # KB id -> set of entity names
    'KBID_TO_ENTITIES': None,
    # KB id -> concatenated attribute text
    'KBID_TO_TEXT': None,
    # KB id -> list of entity types (note: one entity may have several types separated by '|')
    'KBID_TO_TYPES': None,
    # KB id -> list of predicates
    'KBID_TO_PREDICATES': None,

    # Index -> type mapping (list)
    'IDX_TO_TYPE': None,
    # Type -> index mapping (dict)
    'TYPE_TO_IDX': None,
}

for k in PICKLE_DATA:
    filename = k + '.pkl'
    if os.path.exists(PICKLE_PATH + filename):
        PICKLE_DATA[k] = pd.read_pickle(PICKLE_PATH + filename)
    else:
        print(f'File {filename} does not exist!')


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

Process the knowledge-base entity data and clean the text
logger = logging.getLogger(__name__)

class PicklePreprocessor:
    """Preprocessor that builds the global lookup tables and saves them as pickle files"""

    def __init__(self):

        # Entity name -> set of KB ids, e.g. "张健" -> {"10001"}
        self.entity_to_kbids = defaultdict(set)

        # KB id -> set of entity names, e.g. "10001" -> {"张健"}
        self.kbid_to_entities = dict()

        # KB id -> attribute text, e.g. "10001" ->
        # "政治面貌:中共党员 义项描述:潜山县塔畈乡副主任科员、纪委副书记 性别:男 学历:大专 中文名:张健"
        self.kbid_to_text = dict()

        # KB id -> list of entity types, e.g. "10001" -> ["Person"]
        self.kbid_to_types = dict()

        # KB id -> list of predicates, e.g. "10001" -> ["政治面貌", "义项描述", "性别", "学历", "中文名"]
        self.kbid_to_predicates = dict()

        # Index -> type mapping, e.g. ["Person"]
        self.idx_to_type = list()

        # Type -> index mapping, e.g. {"Person": 0}
        self.type_to_idx = dict()

    def run(self, shuffle_text=True):
        with open(RAW_PATH + 'kb.json', 'r',encoding='utf-8') as f:
            for line in tqdm(f):
                line = json.loads(line)

                kbid = line['subject_id']
                # Merge the subject name with its aliases
                entities = set(line['alias'])
                entities.add(line['subject'])
                for entity in entities:
                    self.entity_to_kbids[entity].add(kbid)
                self.kbid_to_entities[kbid] = entities

                text_list, predicate_list = [], []
                for x in line['data']:
                    # Simply concatenate predicate and object; other ways of building this text are worth trying
                    text_list.append(':'.join([x['predicate'].strip(), x['object'].strip()]))
                    predicate_list.append(x['predicate'].strip())
                if shuffle_text:  # Randomly shuffle the order of the attribute snippets
                    random.shuffle(text_list)
                self.kbid_to_predicates[kbid] = predicate_list
                self.kbid_to_text[kbid] = ' '.join(text_list)
                
                # Remove special whitespace characters from the text
                for c in ['\r', '\t', '\n']:
                    self.kbid_to_text[kbid] = self.kbid_to_text[kbid].replace(c, '')

                type_list = line['type'].split('|')
                self.kbid_to_types[kbid] = type_list
                for t in type_list:
                    if t not in self.type_to_idx:
                        self.type_to_idx[t] = len(self.idx_to_type)
                        self.idx_to_type.append(t)

        # Save the pickle files
        pd.to_pickle(self.entity_to_kbids, PICKLE_PATH + 'ENTITY_TO_KBIDS.pkl')
        pd.to_pickle(self.kbid_to_entities, PICKLE_PATH + 'KBID_TO_ENTITIES.pkl')
        pd.to_pickle(self.kbid_to_text, PICKLE_PATH + 'KBID_TO_TEXT.pkl')
        pd.to_pickle(self.kbid_to_types, PICKLE_PATH + 'KBID_TO_TYPES.pkl')
        pd.to_pickle(self.kbid_to_predicates, PICKLE_PATH + 'KBID_TO_PREDICATES.pkl')
        pd.to_pickle(self.idx_to_type, PICKLE_PATH + 'IDX_TO_TYPE.pkl')
        pd.to_pickle(self.type_to_idx, PICKLE_PATH + 'TYPE_TO_IDX.pkl')
        logger.info('Process Pickle File Finish.')

Build the training, validation and test data and save it as tsv files
class DataFramePreprocessor:
    """生成模型训练、验证、推断所需的tsv文件"""

    def __init__(self):
        pass

    def process_link_data(self, input_path, output_path, max_negs=-1):

        entity_to_kbids = PICKLE_DATA['ENTITY_TO_KBIDS']
        kbid_to_text = PICKLE_DATA['KBID_TO_TEXT']
        kbid_to_predicates = PICKLE_DATA['KBID_TO_PREDICATES']
        link_dict = defaultdict(list)

        with open(input_path, 'r',encoding='utf-8') as f:
            for line in tqdm(f):
                line = json.loads(line)

                for data in line['mention_data']:
                    # Special handling for the test set, which has no gold kb_id
                    if 'kb_id' not in data:
                        data['kb_id'] = '0'

                    # NIL mentions (non-numeric kb_id) are skipped for the linking task
                    if not data['kb_id'].isdigit():
                        continue

                    entity = data['mention']
                    kbids = list(entity_to_kbids[entity])
                    random.shuffle(kbids)

                    num_negs = 0
                    for kbid in kbids:
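                        # Keep at most max_negs negative candidates per mention (only when max_negs > 0); the gold kb_id is never skipped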
                        if num_negs >= max_negs > 0 and kbid != data['kb_id']:
                            continue

                        link_dict['text_id'].append(line['text_id'])
                        link_dict['entity'].append(entity)
                        link_dict['offset'].append(data['offset'])
                        link_dict['short_text'].append(line['text'])
                        link_dict['kb_id'].append(kbid)
                        link_dict['kb_text'].append(kbid_to_text[kbid])
                        link_dict['kb_predicate_num'].append(len(kbid_to_predicates[kbid]))
                        if kbid != data['kb_id']:
                            link_dict['predict'].append(0)
                            num_negs += 1
                        else:
                            link_dict['predict'].append(1)

        link_data = pd.DataFrame(link_dict)
        link_data.to_csv(output_path, index=False, sep='\t')

    def process_type_data(self, input_path, output_path):
        kbid_to_types = PICKLE_DATA['KBID_TO_TYPES']
        type_dict = defaultdict(list)

        with open(input_path, 'r',encoding='utf-8') as f:
            for line in tqdm(f):
                line = json.loads(line)

                for data in line['mention_data']:
                    entity = data['mention']

                    # Special handling for the test set: use the placeholder type 'Other'
                    if 'kb_id' not in data:
                        entity_type = ['Other']
                    elif data['kb_id'].isdigit():
                        entity_type = kbid_to_types[data['kb_id']]
                    else:
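                        # NIL mentions store their types in kb_id as 'NIL_Type1|NIL_Type2'; split and strip the 4-character 'NIL_' prefix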
                        entity_type = data['kb_id'].split('|')
                        for x in range(len(entity_type)):
                            entity_type[x] = entity_type[x][4:]
                    for e in entity_type:
                        type_dict['text_id'].append(line['text_id'])
                        type_dict['entity'].append(entity)
                        type_dict['offset'].append(data['offset'])
                        type_dict['short_text'].append(line['text'])
                        type_dict['type'].append(e)

        type_data = pd.DataFrame(type_dict)
        type_data.to_csv(output_path, index=False, sep='\t')

    def run(self):
        self.process_link_data(
            input_path=RAW_PATH + 'train.json',
            output_path=TSV_PATH + 'EL_TRAIN.tsv',
            max_negs=2,
        )
        logger.info('Process EL_TRAIN Finish.')
        self.process_link_data(
            input_path=RAW_PATH + 'dev.json',
            output_path=TSV_PATH + 'EL_VALID.tsv',
            max_negs=-1,
        )
        logger.info('Process EL_VALID Finish.')
        self.process_link_data(
            input_path=RAW_PATH + 'test.json',
            output_path=TSV_PATH + 'EL_TEST.tsv',
            max_negs=-1,
        )
        logger.info('Process EL_TEST Finish.')

        self.process_type_data(
            input_path=RAW_PATH + 'train.json',
            output_path=TSV_PATH + 'ET_TRAIN.tsv',
        )
        logger.info('Process ET_TRAIN Finish.')
        self.process_type_data(
            input_path=RAW_PATH + 'dev.json',
            output_path=TSV_PATH + 'ET_VALID.tsv',
        )
        logger.info('Process ET_VALID Finish.')
        self.process_type_data(
            input_path=RAW_PATH + 'test.json',
            output_path=TSV_PATH + 'ET_TEST.tsv',
        )
        logger.info('Process ET_TEST Finish.')
Take a quick look at the training data
train_data = pd.read_csv(TSV_PATH + 'ET_TRAIN.tsv', sep='\t')
train_data.head()
   text_id  entity      offset  short_text                                                      type
0  1        小品        0       小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人    Other
1  1        战狼故事    3       小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人    Work
2  1        吴京        10      小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人    Person
3  1        障碍        16      小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人    Other
4  1        爱人        20      小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人    Other
Take a quick look at the validation data
valid_data = pd.read_csv(TSV_PATH + 'ET_VALID.tsv', sep='\t')
valid_data.head()
   text_id  entity              offset  short_text                           type
0  1        天下没有不散的宴席  0       天下没有不散的宴席 - ╰つ雲中帆╰つ     Work
1  1        ╰つ雲中帆╰つ        12      天下没有不散的宴席 - ╰つ雲中帆╰つ     Other
2  2        永嘉                0       永嘉厂房出租                         Location
3  2        厂房                2       永嘉厂房出租                         Location
4  2        出租                4       永嘉厂房出租                         Other
Take a quick look at the test data
test_data = pd.read_csv(TSV_PATH + 'ET_TEST.tsv', sep='\t')
test_data.head()
   text_id  entity    offset  short_text                                                      type
0  1        林平之    0       林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了              Other
1  1        岳灵珊    5       林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了              Other
2  1        师娘      18      林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了              Other
3  1        令狐冲    21      林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了              Other
4  2        思追      0       思追原来是个超级妹控,不愿妹妹嫁人,然而妹妹却喜欢一博老师      Other
Entity typing data processing
class EntityTypingProcessor(DataProcessor):
    """Entity typing data processor"""

    def get_train_examples(self, file_path):
        return self._create_examples(
            self._read_tsv(file_path),
            set_type='train',
        )

    def get_dev_examples(self, file_path):
        return self._create_examples(
            self._read_tsv(file_path),
            set_type='valid',
        )

    def get_test_examples(self, file_path):
        return self._create_examples(
            self._read_tsv(file_path),
            set_type='test',
        )

    def get_labels(self):
        return PICKLE_DATA['IDX_TO_TYPE']

    def _create_examples(self, lines, set_type):
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f'{set_type}-{i}'
            text_a = line[1]
            text_b = line[3]
            label = line[-1]
            examples.append(InputExample(
                guid=guid,
                text_a=text_a,
                text_b=text_b,
                label=label,
            ))
        return examples

    def create_dataloader(self, examples, tokenizer, max_length=64,
                          shuffle=False, batch_size=64, use_pickle=False):
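        # Examples are converted to BERT features once and cached as ET_FEATURE_<SPLIT>.pkl; pass use_pickle=True later to reuse the cache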
        pickle_name = 'ET_FEATURE_' + examples[0].guid.split('-')[0].upper() + '.pkl'
        if use_pickle:
            features = pd.read_pickle(PICKLE_PATH + pickle_name)
        else:
            features = glue_convert_examples_to_features(
                examples,
                tokenizer,
                label_list=self.get_labels(),
                max_length=max_length,
                output_mode='classification',
            )
            pd.to_pickle(features, PICKLE_PATH + pickle_name)

        dataset = torch.utils.data.TensorDataset(
            torch.LongTensor([f.input_ids for f in features]),
            torch.LongTensor([f.attention_mask for f in features]),
            torch.LongTensor([f.token_type_ids for f in features]),
            torch.LongTensor([f.label for f in features]),
        )

        dataloader = torch.utils.data.DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=2,
        )
        return dataloader

    def generate_feature_pickle(self, max_length):
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

        train_examples = self.get_train_examples(TSV_PATH + 'ET_TRAIN.tsv')
        valid_examples = self.get_dev_examples(TSV_PATH + 'ET_VALID.tsv')
        test_examples = self.get_test_examples(TSV_PATH + 'ET_TEST.tsv')

        self.create_dataloader(
            examples=train_examples,
            tokenizer=tokenizer,
            max_length=max_length,
            shuffle=True,
            batch_size=32,
            use_pickle=False,
        )
        self.create_dataloader(
            examples=valid_examples,
            tokenizer=tokenizer,
            max_length=max_length,
            shuffle=False,
            batch_size=32,
            use_pickle=False,
        )
        self.create_dataloader(
            examples=test_examples,
            tokenizer=tokenizer,
            max_length=max_length,
            shuffle=False,
            batch_size=32,
            use_pickle=False,
        )
Entity linking data processing
class EntityLinkingProcessor(DataProcessor):
    """Entity linking data processor"""

    def get_train_examples(self, file_path):
        return self._create_examples(
            self._read_tsv(file_path),
            set_type='train',
        )

    def get_dev_examples(self, file_path):
        return self._create_examples(
            self._read_tsv(file_path),
            set_type='valid',
        )

    def get_test_examples(self, file_path):
        return self._create_examples(
            self._read_tsv(file_path),
            set_type='test',
        )

    def get_labels(self):
        return ['0', '1']

    def _create_examples(self, lines, set_type):
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f'{set_type}-{i}'
            text_a = line[1] + ' ' + line[3]
            text_b = line[5]
            label = line[-1]
            examples.append(InputExample(
                guid=guid,
                text_a=text_a,
                text_b=text_b,
                label=label,
            ))
        return examples

    def create_dataloader(self, examples, tokenizer, max_length=384,
                          shuffle=False, batch_size=32, use_pickle=False):
        pickle_name = 'EL_FEATURE_' + examples[0].guid.split('-')[0].upper() + '.pkl'
        if use_pickle:
            features = pd.read_pickle(PICKLE_PATH + pickle_name)
        else:
            features = glue_convert_examples_to_features(
                examples,
                tokenizer,
                label_list=self.get_labels(),
                max_length=max_length,
                output_mode='classification',                
            )

            pd.to_pickle(features, PICKLE_PATH + pickle_name)

        dataset = torch.utils.data.TensorDataset(
            torch.LongTensor([f.input_ids for f in features]),
            torch.LongTensor([f.attention_mask for f in features]),
            torch.LongTensor([f.token_type_ids for f in features]),
            torch.LongTensor([f.label for f in features]),
        )

        dataloader = torch.utils.data.DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=2,
        )
        return dataloader

    def generate_feature_pickle(self, max_length):
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        print("asdgfsdfg")

        train_examples = self.get_train_examples(TSV_PATH + 'EL_TRAIN.tsv')
        valid_examples = self.get_dev_examples(TSV_PATH + 'EL_VALID.tsv')
        test_examples = self.get_test_examples(TSV_PATH + 'EL_TEST.tsv')

        self.create_dataloader(
            examples=train_examples,
            tokenizer=tokenizer,
            max_length=max_length,
            shuffle=True,
            batch_size=32,
            use_pickle=False,
        )
        self.create_dataloader(
            examples=valid_examples,
            tokenizer=tokenizer,
            max_length=max_length,
            shuffle=False,
            batch_size=32,
            use_pickle=False,
        )
        self.create_dataloader(
            examples=test_examples,
            tokenizer=tokenizer,
            max_length=max_length,
            shuffle=False,
            batch_size=32,
            use_pickle=False,
        )

Entity linking model
class EntityLinkingModel(pl.LightningModule):
    """Entity linking model"""

    def __init__(self, max_length=384, batch_size=32, use_pickle=True):
        super(EntityLinkingModel, self).__init__()
        # Maximum input sequence length
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_pickle = use_pickle

        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

        self.bert = BertForSequenceClassification.from_pretrained(
            "hfl/chinese-roberta-wwm-ext",
            num_labels=1,
        )

        # Binary classification loss
        self.criterion = nn.BCEWithLogitsLoss()
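        # With num_labels=1 the classification head emits a single matching score per (mention context, candidate entity) pair;
        # BCEWithLogitsLoss applies the sigmoid internally, so forward() returns raw logits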

    def forward(self, input_ids, attention_mask, token_type_ids):
        logits = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]
        return logits.squeeze()

    def prepare_data(self):
   
        self.processor = EntityLinkingProcessor()
        self.train_examples = self.processor.get_train_examples(TSV_PATH + 'EL_TRAIN.tsv')
        self.valid_examples = self.processor.get_dev_examples(TSV_PATH + 'EL_VALID.tsv')
        self.test_examples = self.processor.get_test_examples(TSV_PATH + 'EL_TEST.tsv')

        self.train_loader = self.processor.create_dataloader(
            examples=self.train_examples,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=True,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )
        self.valid_loader = self.processor.create_dataloader(
            examples=self.valid_examples,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=False,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )
        self.test_loader = self.processor.create_dataloader(
            examples=self.test_examples,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=False,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )
        print("finish")

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, labels = batch
        logits = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(logits, labels.float())

        preds = (logits > 0).int()
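        # A logit above 0 corresponds to a sigmoid probability above 0.5, i.e. a positive match prediction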
        acc = (preds == labels).float().mean()

        tensorboard_logs = {'train_loss': loss, 'train_acc': acc}
        return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, labels = batch
        logits = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(logits, labels.float())

        preds = (logits > 0).int()
        acc = (preds == labels).float().mean()

        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        tensorboard_logs = {'val_loss': val_loss, 'val_acc': val_acc}
        return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-5, eps=1e-8)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.valid_loader
Entity linking inference
class EntityLinkingPredictor:

    def __init__(self, ckpt_name, batch_size=8, use_pickle=True):
        self.ckpt_name = ckpt_name
        self.batch_size = batch_size
        self.use_pickle = use_pickle

    def generate_tsv_result(self, tsv_name, tsv_type='Valid'):
        processor = EntityLinkingProcessor()
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

        if tsv_type == 'Valid':
            examples = processor.get_dev_examples(TSV_PATH + tsv_name)
        elif tsv_type == 'Test':
            examples = processor.get_test_examples(TSV_PATH + tsv_name)
        else:
            raise ValueError('tsv_type error')
        dataloader = processor.create_dataloader(
            examples=examples,
            tokenizer=tokenizer,
            max_length=384,
            shuffle=False,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )

        model = EntityLinkingModel.load_from_checkpoint(
            checkpoint_path=CKPT_PATH + self.ckpt_name,
        )
        model.to(DEVICE)
        model = nn.DataParallel(model)
        model.eval()

        result_list, logit_list = [], []
        for batch in tqdm(dataloader):
            for i in range(len(batch)):
                batch[i] = batch[i].to(DEVICE)

            input_ids, attention_mask, token_type_ids, labels = batch
            logits = model(input_ids, attention_mask, token_type_ids)
            preds = (logits > 0).int()

            result_list.extend(preds.tolist())
            logit_list.extend(logits.tolist())

        tsv_data = pd.read_csv(TSV_PATH + tsv_name, sep='\t')
        tsv_data['logits'] = logit_list
        tsv_data['result'] = result_list
        result_name = tsv_name.split('.')[0] + '_RESULT.tsv'
        tsv_data.to_csv(RESULT_PATH + result_name, index=False, sep='\t')
Entity typing model

class EntityTypingModel(pl.LightningModule):
    """实体类型推断模型"""

    def __init__(self, max_length=64, batch_size=64, use_pickle=True):
        super(EntityTypingModel, self).__init__()
        # Maximum input sequence length
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_pickle = use_pickle
        # Multi-class classification loss (one class per entity type)
        self.criterion = nn.CrossEntropyLoss()

        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

        # Pretrained model with a classification head over all entity types
        self.bert = BertForSequenceClassification.from_pretrained(
            "hfl/chinese-roberta-wwm-ext",
            num_labels=len(PICKLE_DATA['IDX_TO_TYPE']),
        )


    def forward(self, input_ids, attention_mask, token_type_ids):
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]

    def prepare_data(self):
        self.processor = EntityTypingProcessor()
        self.train_examples = self.processor.get_train_examples(TSV_PATH + 'ET_TRAIN.tsv')
        self.valid_examples = self.processor.get_dev_examples(TSV_PATH + 'ET_VALID.tsv')
        self.test_examples = self.processor.get_test_examples(TSV_PATH + 'ET_TEST.tsv')

        self.train_loader = self.processor.create_dataloader(
            examples=self.train_examples,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=True,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )
        self.valid_loader = self.processor.create_dataloader(
            examples=self.valid_examples,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=False,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )
        self.test_loader = self.processor.create_dataloader(
            examples=self.test_examples,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=False,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, labels = batch
        outputs = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(outputs, labels)

        _, preds = torch.max(outputs, dim=1)
        acc = (preds == labels).float().mean()

        tensorboard_logs = {'train_loss': loss, 'train_acc': acc}
        return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, labels = batch
        outputs = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(outputs, labels)

        _, preds = torch.max(outputs, dim=1)
        acc = (preds == labels).float().mean()

        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        tensorboard_logs = {'val_loss': val_loss, 'val_acc': val_acc}
        return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-5, eps=1e-8)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.valid_loader

Entity typing inference
class EntityTypingPredictor:

    def __init__(self, ckpt_name, batch_size=8, use_pickle=True):
        self.ckpt_name = ckpt_name
        self.batch_size = batch_size
        self.use_pickle = use_pickle

    def generate_tsv_result(self, tsv_name, tsv_type='Valid'):
        processor = EntityTypingProcessor()
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

        if tsv_type == 'Valid':
            examples = processor.get_dev_examples(TSV_PATH + tsv_name)
        elif tsv_type == 'Test':
            examples = processor.get_test_examples(TSV_PATH + tsv_name)
        else:
            raise ValueError('tsv_type error')
        dataloader = processor.create_dataloader(
            examples=examples,
            tokenizer=tokenizer,
            max_length=64,
            shuffle=False,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )

        model = EntityTypingModel.load_from_checkpoint(
            checkpoint_path=CKPT_PATH + self.ckpt_name,
        )
        model.to(DEVICE)
        model = nn.DataParallel(model)
        model.eval()

        result_list = []
        for batch in tqdm(dataloader):
            for i in range(len(batch)):
                batch[i] = batch[i].to(DEVICE)

            input_ids, attention_mask, token_type_ids, labels = batch
            outputs = model(input_ids, attention_mask, token_type_ids)
            _, preds = torch.max(outputs, dim=1)
            result_list.extend(preds.tolist())

        idx_to_type = PICKLE_DATA['IDX_TO_TYPE']
        result_list = [idx_to_type[x] for x in result_list]
        tsv_data = pd.read_csv(TSV_PATH + tsv_name, sep='\t')
        tsv_data['result'] = result_list
        result_name = tsv_name.split('.')[0] + '_RESULT.tsv'
        tsv_data.to_csv(RESULT_PATH + result_name, index=False, sep='\t')
Putting it all together: preprocessing, training and prediction
def preprocess_pickle_file():
    processor = PicklePreprocessor()
    processor.run()


def preprocess_tsv_file():
    processor = DataFramePreprocessor()
    processor.run()


def generate_feature_pickle():
    processor = EntityLinkingProcessor()
    processor.generate_feature_pickle(max_length=384)

    processor = EntityTypingProcessor()
    processor.generate_feature_pickle(max_length=64)

def train_entity_linking_model(ckpt_name):

    model = EntityLinkingModel(max_length=384, batch_size=32)
    trainer = pl.Trainer(
        max_epochs=1,
        gpus=1,
        distributed_backend='dp',
        default_save_path=EL_SAVE_PATH,
        profiler=True,
    )
    trainer.fit(model)
    trainer.save_checkpoint(CKPT_PATH + ckpt_name)

def train_entity_typing_model(ckpt_name):
    model = EntityTypingModel(max_length=64, batch_size=64)
    trainer = pl.Trainer(
        max_epochs=1,
        gpus=1,
        distributed_backend='dp',
        default_save_path=ET_SAVE_PATH,
        profiler=True,
    )
    trainer.fit(model)
    trainer.save_checkpoint(CKPT_PATH + ckpt_name)


def generate_link_tsv_result(ckpt_name):
    predictor = EntityLinkingPredictor(ckpt_name, batch_size=24, use_pickle=True)
    predictor.generate_tsv_result('EL_VALID.tsv', tsv_type='Valid')
    predictor.generate_tsv_result('EL_TEST.tsv', tsv_type='Test')


def generate_type_tsv_result(ckpt_name):
    predictor = EntityTypingPredictor(ckpt_name, batch_size=64, use_pickle=True)
    predictor.generate_tsv_result('ET_VALID.tsv', tsv_type='Valid')
    predictor.generate_tsv_result('ET_TEST.tsv', tsv_type='Test')


def make_predication_result(input_name, output_name, el_ret_name, et_ret_name):
    entity_to_kbids = PICKLE_DATA['ENTITY_TO_KBIDS']
    
    el_ret = pd.read_csv(
        RESULT_PATH + el_ret_name, sep='\t', dtype={
            'text_id': np.str_,
            'offset': np.str_,
            'kb_id': np.str_
        })
    
    et_ret = pd.read_csv(RESULT_PATH + et_ret_name, sep='\t', dtype={'text_id': np.str_, 'offset': np.str_})
 
    result = []
    with open(RAW_PATH + input_name, 'r',encoding="utf-8") as f:
        for line in tqdm(f):
            line = json.loads(line)
            for data in line['mention_data']:
                text_id = line['text_id']
                offset = data['offset']

                candidate_data = el_ret[(el_ret['text_id'] == text_id) & (el_ret['offset'] == offset)]
                # Entity linking: if any candidate scores a positive logit, take the candidate with the highest logit
                if len(candidate_data) > 0 and candidate_data['logits'].max() > 0:
                    max_idx = candidate_data['logits'].idxmax()
                    data['kb_id'] = candidate_data.loc[max_idx]['kb_id']
                # Entity typing fallback: no confident KB candidate, so output NIL_<predicted type>
                else:
                    type_data = et_ret[(et_ret['text_id'] == text_id) & (et_ret['offset'] == offset)]
                    data['kb_id'] = 'NIL_' + type_data.iloc[0]['result']
            result.append(line)

    with open(RESULT_PATH + output_name, 'w',encoding="utf-8") as f:
        for r in result:
            json.dump(r, f, ensure_ascii=False)
            f.write('\n')



set_random_seed(20200619)
preprocess_pickle_file()
preprocess_tsv_file()
generate_feature_pickle()

train_entity_linking_model('EL_BASE_EPOCH0.ckpt')
generate_link_tsv_result('EL_BASE_EPOCH0.ckpt')
train_entity_typing_model('ET_BASE_EPOCH1.ckpt')
generate_type_tsv_result('ET_BASE_EPOCH1.ckpt')
Take a quick look at the predictions on the test data
el_ret = pd.read_csv("./data/result/ET_TEST_RESULT.tsv", sep='\t')
el_ret.head()
   text_id  entity    offset  short_text                                                      type   result
0  1        林平之    0       林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了              Other  Person
1  1        岳灵珊    5       林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了              Other  Person
2  1        师娘      18      林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了              Other  Other
3  1        令狐冲    21      林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了              Other  Person
4  2        思追      0       思追原来是个超级妹控,不愿妹妹嫁人,然而妹妹却喜欢一博老师      Other  Person
el_ret = pd.read_csv("./data/result/ET_VALID_RESULT.tsv", sep='\t')        
el_ret.head()
   text_id  entity              offset  short_text                           type      result
0  1        天下没有不散的宴席  0       天下没有不散的宴席 - ╰つ雲中帆╰つ     Work      Work
1  1        ╰つ雲中帆╰つ        12      天下没有不散的宴席 - ╰つ雲中帆╰つ     Other     Person
2  2        永嘉                0       永嘉厂房出租                         Location  Location
3  2        厂房                2       永嘉厂房出租                         Location  Other
4  2        出租                4       永嘉厂房出租                         Other     Other
path = "./data/result/ET_VALID_RESULT.tsv"        
data = pd.read_csv(
    path, sep='\t', dtype={
        'text_id': np.str_,
        'offset': np.str_,
        'kb_id': np.str_
    })
data.head()
   text_id  entity              offset  short_text                           type      result
0  1        天下没有不散的宴席  0       天下没有不散的宴席 - ╰つ雲中帆╰つ     Work      Work
1  1        ╰つ雲中帆╰つ        12      天下没有不散的宴席 - ╰つ雲中帆╰つ     Other     Person
2  2        永嘉                0       永嘉厂房出租                         Location  Location
3  2        厂房                2       永嘉厂房出租                         Location  Other
4  2        出租                4       永嘉厂房出租                         Other     Other
make_predication_result('dev.json', 'valid_result.json', 'EL_VALID_RESULT.tsv', 'ET_VALID_RESULT.tsv')
10000it [13:17, 12.55it/s]
make_predication_result('test.json', 'test_result.json', 'EL_TEST_RESULT.tsv', 'ET_TEST_RESULT.tsv')
10000it [31:54,  5.22it/s]
Compute the evaluation score (Micro-F1)
# !/bin/env python
# -*- coding: utf-8 -*-
#####################################################################################
#
#  Copyright (c) CCKS 2020 Entity Linking Organizing Committee.
#  All Rights Reserved.
#
#####################################################################################
"""
@version 2020-03-30
@brief:
    Entity Linking效果评估脚本,评价指标Micro-F1
"""
# import sys

# reload(sys)
# sys.setdefaultencoding('utf-8')
import json
from collections import defaultdict


class Eval(object):
    """
    Entity Linking Evaluation
    """

    def __init__(self, golden_file_path, user_file_path):
        self.golden_file_path = golden_file_path
        self.user_file_path = user_file_path
        self.tp = 0
        self.fp = 0
        self.total_recall = 0
        self.errno = None

    def format_check(self, file_path):
        """
        文件格式验证
        :param file_path: 文件路径
        :return: Bool类型:是否通过格式检查,通过为True,反之False
        """
        flag = True
        for line in open(file_path,encoding='utf-8'):
            json_info = json.loads(line.strip())
            if 'text_id' not in json_info:
                flag = False
                self.errno = 1
                break
            if 'text' not in json_info:
                flag = False
                self.errno = 2
                break
            if 'mention_data' not in json_info:
                flag = False
                self.errno = 3
                break
            if not json_info['text_id'].isdigit():
                flag = False
                self.errno = 5
                break           
            if not isinstance(json_info['mention_data'], list):
                flag = False
                self.errno = 7
                break
            for mention_info in json_info['mention_data']:
                if 'kb_id' not in mention_info:
                    flag = False
                    self.errno = 7
                    break
                if 'mention' not in mention_info:
                    flag = False
                    self.errno = 8
                    break
                if 'offset' not in mention_info:
                    flag = False
                    self.errno = 9
                    break                
                if not mention_info['offset'].isdigit():
                    flag = False
                    self.errno = 13
                    break
        return flag

    def micro_f1(self):
        """
        :return: float类型:精确率,召回率,Micro-F1值
        """
        # Validate the formats of both files
        flag_golden = self.format_check(self.golden_file_path)
        flag_user = self.format_check(self.user_file_path)
        # If either format check fails, return None
        if not flag_golden or not flag_user:
            return None, None, None
        precision = 0
        recall = 0
        self.tp = 0
        self.fp = 0
        self.total_recall = 0
        golden_dict = defaultdict(list)
        for line in open(self.golden_file_path,encoding='utf-8'):
            golden_info = json.loads(line.strip())
            text_id = golden_info['text_id']
            text = golden_info['text']
            mention_data = golden_info['mention_data']
            for mention_info in mention_data:
                kb_id = mention_info['kb_id']
                mention = mention_info['mention']
                offset = mention_info['offset']
                key = '\1'.join([text_id, text, mention, offset]).encode('utf8')
                # The second element of the value is a flag marking whether this gold mention has already been counted
                golden_dict[key] = [kb_id, 0]
                self.total_recall += 1

        # Score the user predictions against the gold annotations
        for line in open(self.user_file_path,encoding='utf-8'):
            golden_info = json.loads(line.strip())
            text_id = golden_info['text_id']
            text = golden_info['text']
            mention_data = golden_info['mention_data']
            for mention_info in mention_data:
                kb_id = mention_info['kb_id']
                mention = mention_info['mention']
                offset = mention_info['offset']
                key = '\1'.join([text_id, text, mention, offset]).encode('utf8')
                if key in golden_dict:
                    kb_result_golden = golden_dict[key]
                    if kb_id.isdigit():
                        if kb_id in [kb_result_golden[0]] and kb_result_golden[1] in [0]:
                            self.tp += 1
                        else:
                            self.fp += 1
                    else:
                        # The gold answer is NIL; it may list several acceptable types separated by '|'
                        nil_res = kb_result_golden[0].split('|')
                        if kb_id in nil_res and kb_result_golden[1] in [0]:
                            self.tp += 1
                        else:
                            self.fp += 1
                    golden_dict[key][1] = 1
                else:
                    self.fp += 1
        if self.tp + self.fp > 0:
            precision = float(self.tp) / (self.tp + self.fp)
        if self.total_recall > 0:
            recall = float(self.tp) / self.total_recall
        a = 2 * precision * recall
        b = precision + recall
        if b == 0:
            return 0, 0, 0
        f1 = a / b
        return precision, recall, f1
evaluator = Eval('./data/ccks2020_el_data_v1/dev.json', './data/result/valid_result.json')

prec, recall, f1 = evaluator.micro_f1()
print(prec, recall, f1)
if evaluator.errno:
    print(evaluator.errno)

0.8488185641242852 0.8488185641242852 0.8488185641242852

Related link: CCKS 2020, the China Conference on Knowledge Graph and Semantic Computing: http://sigkg.cn/ccks2020/?page_id=69
