评估通用教程-huggingface-datasets
使用datasets先进行预测,然后再计算指标。
# from models.ner import NER
import torch
from datasets_utils.load_datasets import Datasets
from config import config
'''
必须包含方法:
predict:利用torch.no_grad()、datasets.map对验证集进行预测。
evaluate:根据预测结果进行统计计算。
根据任务情况,自己定义函数:
extract_entities: 识别一句话中所有实体。
extract_single:extract_entities方法的辅助
get_real_entities:利用extract_entities取出一句话中的真实实体。
get_pred_entities: 利用模型预测,使用extract_entities取出一句话中的预测实体
'''
class Evaluate:
    """Entity-level evaluation of a saved NER model on the validation split.

    ``predict`` walks the validation set twice, collecting the gold entity
    strings and the model's predicted entity strings per class;
    ``evaluate`` then derives recall and precision from those lists.

    Attributes:
        model: Object returned by ``torch.load``; must expose ``predict(text)``.
        valid_datasets: Result of ``Datasets().load_datasets('valid')``; must
            expose ``map(function, batched=False)``.
        next_tags: Maps each B-tag index to the I-tag index that continues it.
        tags_class: Maps each B-tag index to its entity class name.
        real_entities: Gold entity strings per class, filled by ``predict``.
        pred_entities: Predicted entity strings per class, filled by ``predict``.
        reuslt: Scratch accumulator written by ``extract_entities``. The
            misspelled name is kept because external code may read it.
        correct_num, incorrect_num, total_num: Counters filled by ``evaluate``.
    """

    def __init__(self, model_path):
        """Load the model and the validation split, and build the tag tables.

        Args:
            model_path: Path of the checkpoint handed to ``torch.load``.
        """
        # NOTE(review): torch.load unpickles arbitrary Python objects -- only
        # load checkpoints from a trusted source. Recent PyTorch releases may
        # also default to weights_only=True, which rejects whole-model
        # pickles; confirm against the installed version.
        # NOTE(review): the model is used as loaded; if it is an nn.Module it
        # may still be in training mode -- confirm whether eval() is needed.
        self.model = torch.load(model_path)
        self.valid_datasets = Datasets().load_datasets('valid')
        B_dis, I_dis = config.tags_to_index['B-dis'], config.tags_to_index['I-dis']
        B_sym, I_sym = config.tags_to_index['B-sym'], config.tags_to_index['I-sym']
        self.next_tags = {B_sym: I_sym, B_dis: I_dis}
        self.tags_class = {B_sym: 'SYM', B_dis: 'DIS'}
        self.real_entities = self._empty_result()
        self.pred_entities = self._empty_result()
        self.reuslt = self._empty_result()
        self.correct_num = 0
        self.incorrect_num = 0
        self.total_num = 0

    @staticmethod
    def _empty_result():
        """Return a fresh, unshared per-class entity accumulator."""
        return {'SYM': [], 'DIS': []}

    def extract_entities(self, text, labels):
        """Append every entity found in one sentence to ``self.reuslt``.

        Args:
            text: Indexable sequence of characters, read at the same positions
                as ``labels``.
            labels: Sequence of tag indices, one per position.
        """
        index = 0
        while index < len(labels):
            label = labels[index]
            if label in self.next_tags:
                # extract_single returns the first position after the entity,
                # so the scan resumes there instead of advancing by one.
                entity, index = self.extract_single(index, text, labels)
                self.reuslt[self.tags_class[label]].append(entity)
            else:
                index += 1

    def extract_single(self, index, text, labels):
        """Read the entity whose B-tag sits at ``index``.

        Args:
            index: Position of the B-tag that opens the entity.
            text: Indexable sequence of characters.
            labels: Sequence of tag indices.

        Returns:
            A ``(entity, next_index)`` tuple: the joined entity string and the
            first position that no longer belongs to the entity.

        Raises:
            IndexError: If the entity span reaches past the end of ``text``.
        """
        start = index
        next_tag = self.next_tags[labels[start]]
        end = start + 1
        # The entity continues for as long as the matching I-tag repeats.
        while end < len(labels) and labels[end] == next_tag:
            end += 1
        # Index position by position (rather than slicing) so a text shorter
        # than its label sequence still fails loudly.
        return ''.join(text[pos] for pos in range(start, end)), end

    def get_real_entities(self, batch):
        """Collect the gold entities of one validation row into ``self.reuslt``.

        Args:
            batch: A single row (``map`` is called with ``batched=False``)
                providing ``'text'`` and ``'label'``.
        """
        labels = batch['label']
        # Whitespace is removed before indexing; presumably the labels are
        # aligned with the whitespace-free characters -- TODO confirm.
        text = ''.join(batch['text'].split())
        self.extract_entities(text, labels)

    def get_pred_entities(self, batch):
        """Collect the predicted entities of one row into ``self.reuslt``.

        Args:
            batch: A single row providing ``'text'``.
        """
        text = batch['text']
        # The model receives the raw text; the first element of its output is
        # used as the label sequence for this sentence.
        labels = self.model.predict(text)[0]
        text = ''.join(text.split())
        self.extract_entities(text, labels)

    def predict(self):
        """Fill ``real_entities`` and ``pred_entities`` from the validation set.

        Each pass starts from a fresh accumulator, so repeated calls replace
        the previous results instead of appending to them.
        """
        with torch.no_grad():
            self.reuslt = self._empty_result()
            self.valid_datasets.map(self.get_real_entities, batched=False)
            self.real_entities = self.reuslt

            self.reuslt = self._empty_result()
            self.valid_datasets.map(self.get_pred_entities, batched=False)
            self.pred_entities = self.reuslt

    def evaluate(self):
        """Run ``predict`` and print entity-level recall and precision.

        A predicted entity counts as correct when the same string occurs among
        the gold entities of its class. The counters are reset on every call,
        and a ratio with an empty denominator is reported as 0.0.
        """
        self.predict()
        self.correct_num = 0
        self.incorrect_num = 0
        self.total_num = 0
        for key, real_words in self.real_entities.items():
            self.total_num += len(real_words)
            # A set gives the same membership answers as the list in O(1).
            real_lookup = set(real_words)
            for pre_word in self.pred_entities[key]:
                if pre_word in real_lookup:
                    self.correct_num += 1
                else:
                    self.incorrect_num += 1
        predicted_num = self.correct_num + self.incorrect_num
        recall = self.correct_num / self.total_num if self.total_num else 0.0
        precision = self.correct_num / predicted_num if predicted_num else 0.0
        print('召回率:', round(recall, 4))
        print('精准率:', round(precision, 4))
if __name__ == '__main__':
    # Script entry point: evaluate the saved NER checkpoint with device
    # index 0 active as the torch device context, then dump the entities.
    with torch.device(0):
        evaluator = Evaluate('../models_saved/ner/ner_best.bin')
        evaluator.evaluate()
        print(evaluator.pred_entities, evaluator.real_entities)