评估通用教程-huggingface-datasets

评估通用教程-huggingface-datasets

使用 datasets 先对验证集进行预测,然后再计算指标。

# from models.ner import NER
import torch
from datasets_utils.load_datasets import Datasets
from config import config

'''
必须包含方法:
predict:利用torch.no_grad()、datasets.map对验证集进行预测。
evaluate:根据预测结果进行统计计算。


根据任务情况,自己定义函数:
extract_entities: 识别一句话中所有实体。
extract_single:extract_entities 的辅助方法,从实体起始位置(B 标签)取出一个完整实体。
get_real_entities:利用extract_entities取出一句话中的真实实体。
get_pred_entities: 利用模型预测标签,再使用extract_entities取出一句话中的预测实体。


'''


class Evaluate:
    """Entity-level recall/precision evaluation of a saved NER model.

    The validation split is scanned twice: once to collect the gold entities
    from the stored tag sequences, and once to collect the entities decoded
    from the model's predicted tags. Entities are compared per class
    ('SYM', 'DIS') as strings, ignoring the sentence/position they came from.
    """

    def __init__(self, model_path):
        """Load the model and validation data and build the tag lookup tables.

        Args:
            model_path: Path of the saved model. The loaded object must expose
                ``predict(text)``, so this is presumably a whole model saved
                with ``torch.save(model)`` rather than a state dict.
        """
        # weights_only=False: a whole pickled model object is rejected by the
        # torch>=2.6 default (weights_only=True). Unpickling can execute
        # arbitrary code -- only load checkpoints you trust.
        self.model = torch.load(model_path, weights_only=False)
        self.valid_datasets = Datasets().load_datasets('valid')
        B_dis, I_dis = config.tags_to_index['B-dis'], config.tags_to_index['I-dis']
        B_sym, I_sym = config.tags_to_index['B-sym'], config.tags_to_index['I-sym']
        # Begin tag -> the inside tag that continues the same entity.
        self.next_tags = {B_sym: I_sym, B_dis: I_dis}
        # Begin tag -> entity class name used as a key of the result dicts.
        self.tags_class = {B_sym: 'SYM', B_dis: 'DIS'}
        self.real_entities = self._empty_result()
        self.pred_entities = self._empty_result()
        # Scratch accumulator filled by extract_entities. The misspelled name
        # is kept unchanged for compatibility with existing code.
        self.reuslt = self._empty_result()
        self.correct_num = 0
        self.incorrect_num = 0
        self.total_num = 0

    @staticmethod
    def _empty_result():
        """Return a fresh, empty per-class entity accumulator."""
        return {'SYM': [], 'DIS': []}

    def extract_entities(self, text, labels):
        """Append every entity found in one sentence to ``self.reuslt``.

        Args:
            text: Sentence whose characters are aligned position-by-position
                with ``labels``.
            labels: Tag indices (values of ``config.tags_to_index``). An entity
                starts at a begin tag listed in ``self.next_tags``.
        """
        index = 0
        while index < len(labels):
            label = labels[index]
            if label in self.next_tags:
                # extract_single returns the index just past the entity, so
                # scanning resumes there without incrementing again.
                entity, index = self.extract_single(index, text, labels)
                self.reuslt[self.tags_class[label]].append(entity)
                continue
            index += 1

    def extract_single(self, index, text, labels):
        """Read one entity starting at the begin tag at ``index``.

        The entity extends over all immediately following positions tagged
        with the matching inside tag.

        Returns:
            (entity_string, index_after_entity)
        """
        entity = [text[index]]
        next_tag = self.next_tags[labels[index]]
        index += 1
        while index < len(labels):
            if labels[index] != next_tag:
                break
            entity.append(text[index])
            index += 1
        return ''.join(entity), index

    def get_real_entities(self, batch):
        """Collect gold entities from one example (``Dataset.map`` callback).

        Reads ``batch['text']`` and ``batch['label']``. Whitespace is removed
        from the text, presumably so that each remaining character lines up
        with one label -- TODO confirm against the dataset format. Returns
        None; the callback is used only for its side effect on ``self.reuslt``.
        """
        text = batch['text']
        labels = batch['label']
        text = ''.join(text.split())
        self.extract_entities(text, labels)

    def get_pred_entities(self, batch):
        """Collect model-predicted entities from one example (map callback).

        The first element of ``self.model.predict(text)`` is used as the tag
        sequence. Returns None; only ``self.reuslt`` is updated.
        """
        text = batch['text']
        labels = self.model.predict(text)[0]
        text = ''.join(text.split())
        self.extract_entities(text, labels)

    def predict(self):
        """Fill ``self.real_entities`` and ``self.pred_entities``.

        Every call starts from empty accumulators, so repeated calls do not
        mix in entities from a previous run.
        """
        with torch.no_grad():
            # load_from_cache_file=False: the mapped callbacks run only for
            # their side effects; a cache hit would skip them entirely and
            # leave the entity lists empty.
            self.reuslt = self._empty_result()
            self.valid_datasets.map(self.get_real_entities, batched=False,
                                    load_from_cache_file=False)
            self.real_entities = self.reuslt
            self.reuslt = self._empty_result()
            self.valid_datasets.map(self.get_pred_entities, batched=False,
                                    load_from_cache_file=False)
            self.pred_entities = self.reuslt

    @staticmethod
    def _count_matches(real, pred):
        """Return (matched, unmatched) counts of ``pred`` against ``real``.

        Both lists are treated as multisets: each gold occurrence can satisfy
        at most one prediction, so matched <= min(len(real), len(pred)).
        """
        from collections import Counter
        matched = sum((Counter(real) & Counter(pred)).values())
        return matched, len(pred) - matched

    @staticmethod
    def _safe_ratio(num, den):
        """Return ``num / den`` rounded to 4 places, or 0.0 when ``den`` is 0."""
        return round(num / den, 4) if den else 0.0

    def evaluate(self):
        """Run prediction, then print and return entity-level recall/precision.

        A predicted entity is correct when the same string occurs among the
        gold entities of the same class, with each gold occurrence matched at
        most once. Counters are reset on every call.

        Returns:
            (recall, precision), each rounded to 4 places; a metric whose
            denominator is zero is reported as 0.0 instead of raising.
        """
        self.predict()
        self.correct_num = 0
        self.incorrect_num = 0
        self.total_num = 0
        for key, real in self.real_entities.items():
            correct, incorrect = self._count_matches(real, self.pred_entities.get(key, []))
            self.total_num += len(real)
            self.correct_num += correct
            self.incorrect_num += incorrect
        recall = self._safe_ratio(self.correct_num, self.total_num)
        precision = self._safe_ratio(self.correct_num, self.correct_num + self.incorrect_num)
        print('召回率:', recall)
        print('精准率:', precision)
        return recall, precision



if __name__ == '__main__':
    # Script entry: evaluate the best saved NER checkpoint with torch device 0
    # as the default device for tensors created in this block, then dump the
    # predicted and gold entity lists.
    model_path = '../models_saved/ner/ner_best.bin'
    with torch.device(0):
        evaluator = Evaluate(model_path)
        evaluator.evaluate()
        print(evaluator.pred_entities, evaluator.real_entities)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值