小黑 NER Exploration: the valid() Function

import sys
import time

import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

sys.path.append('..')  # make the project root importable
from estimate import Precision, Recall, F1_score
from model.BERT_BiLSTM_CRF import BERT_BiLSTM_CRF
from config import Config
from utils import read_corpus, load_vocab
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is available
config = Config()
vocab = load_vocab(config.vocab)           # {token: id}
label_dic = load_vocab(config.label_file)  # {tag: id}
tagset_size = len(label_dic)

dev_data = read_corpus(config.dev_file, max_length=config.max_length, label_dic=label_dic, vocab=vocab)
dev_ids = torch.LongTensor([temp.input_id for temp in dev_data])
dev_masks = torch.LongTensor([temp.input_mask for temp in dev_data])
dev_tags = torch.LongTensor([temp.label_id for temp in dev_data])
dev_dataset = TensorDataset(dev_ids, dev_masks, dev_tags)
# shuffling is unnecessary for evaluation; the metrics do not depend on batch order
dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=config.batch_size)
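
read_corpus lives in this project's utils module and is not reproduced in this post; as the three list comprehensions above show, each element it returns must expose input_id, input_mask, and label_id. For readers without the repo, a minimal sketch of such a container (the name InputFeature is hypothetical, not necessarily what the project uses) could be:

from dataclasses import dataclass
from typing import List

@dataclass
class InputFeature:
    input_id: List[int]    # BERT token ids, padded/truncated to max_length
    input_mask: List[int]  # 1 for real tokens, 0 for padding
    label_id: List[int]    # tag ids from label_dic, aligned with input_id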
model = BERT_BiLSTM_CRF(tagset_size,
                        config.bert_embedding,
                        config.rnn_hidden,
                        config.rnn_layer,
                        config.dropout,
                        config.pretrain_model_name,
                        device
                       ).to(device)
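
The valid function below touches the model through three entry points only. Their signatures, inferred from the call sites rather than from the model file itself, are:

# model(inputs, masks)           -> emission scores, shape [batch_size, max_len, tagset_size]
# model.loss(feats, tags, masks) -> CRF loss for the batch (scalar tensor)
# model.predict(feats, masks)    -> list of decoded tag-id paths, one per sentence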
def valid(model, dataloader):
    """Run one evaluation pass; return (epoch_time, epoch_loss, (precision, recall, f1))."""
    model.eval()
    device = model.device
    pre_output = []   # predicted tag-id paths, grouped by batch
    true_output = []  # gold tag-id paths, truncated to each sentence's real length
    epoch_start = time.time()
    running_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # inputs: [batch_size, max_len]
            # masks:  [batch_size, max_len]
            # tags:   [batch_size, max_len]
            inputs, masks, tags = batch
            # number of real (non-padding) tokens per sentence
            real_length = torch.sum(masks, dim=1).tolist()
            # keep only the gold tags at real token positions
            true_output.append([line[:length]
                                for line, length in zip(tags.tolist(), real_length)])

            inputs = inputs.to(device)
            masks = masks.byte().to(device)  # CRF mask; .bool() is preferred on recent PyTorch
            tags = tags.to(device)
            # feats: emission scores, [batch_size, max_len, tagset_size]
            feats = model(inputs, masks)
            loss = model.loss(feats, tags, masks)
            # one decoded tag path per sentence in the batch
            out_path = model.predict(feats, masks)
            pre_output.append(out_path)

            running_loss += loss.item()
    epoch_time = time.time() - epoch_start
    epoch_loss = running_loss / len(dataloader)
    # pre_output, true_output: [num_batches][batch_size] paths
    precision = Precision(pre_output, true_output)
    recall = Recall(pre_output, true_output)
    f1_score = F1_score(precision, recall)
    estimator = (precision, recall, f1_score)
    return epoch_time, epoch_loss, estimator
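
The estimate module is also project code that this post does not show. Based purely on the call signatures above (nested [num_batches][batch_size] lists of tag-id paths, and F1_score taking precision and recall), a token-level sketch might look like the following; the O_TAG_ID constant and the choice to score only non-"O" positions are assumptions, not the project's confirmed behavior:

O_TAG_ID = 0  # assumption: id of the 'O' (outside) tag in label_dic

def _flatten(nested):
    # [num_batches][batch_size] paths -> one flat list of tag ids
    return [t for batch in nested for path in batch for t in path]

def Precision(pre_output, true_output):
    # of the positions predicted as entity tags, how many carry the correct tag
    # (assumes predicted and gold paths have equal length per sentence)
    pred, gold = _flatten(pre_output), _flatten(true_output)
    predicted = [(i, p) for i, p in enumerate(pred) if p != O_TAG_ID]
    if not predicted:
        return 0.0
    correct = sum(1 for i, p in predicted if gold[i] == p)
    return correct / len(predicted)

def Recall(pre_output, true_output):
    # of the gold entity-tag positions, how many were predicted correctly
    pred, gold = _flatten(pre_output), _flatten(true_output)
    gold_pos = [(i, g) for i, g in enumerate(gold) if g != O_TAG_ID]
    if not gold_pos:
        return 0.0
    correct = sum(1 for i, g in gold_pos if pred[i] == g)
    return correct / len(gold_pos)

def F1_score(precision, recall):
    # harmonic mean; return 0 when both are 0 to avoid division by zero
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0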
best_path = '../result/checkpoints/RoBERTa_result/RoBERTa_best.pth.tar'
checkpoint = torch.load(best_path, map_location=device)  # map_location lets the checkpoint load on CPU too
model.load_state_dict(checkpoint['model'])
valid(model, dev_loader)
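
To inspect actual predictions rather than aggregate metrics, the id paths can be mapped back to tag strings. This assumes load_vocab returns a {tag: id} dict, as its use for tagset_size above suggests:

id2tag = {idx: tag for tag, idx in label_dic.items()}
with torch.no_grad():
    inputs, masks, _ = next(iter(dev_loader))
    masks = masks.byte().to(device)
    feats = model(inputs.to(device), masks)
    paths = model.predict(feats, masks)
print([id2tag[i] for i in paths[0]])  # tag sequence for the first dev sentence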

Output:

(8.039861917495728,
 5.8888872762521105,
 (0.6759776536312849, 0.7393939393939394, 0.7062651005908666))

That is, the evaluation pass took about 8.04 s, the mean CRF loss was about 5.89, and (precision, recall, F1) ≈ (0.676, 0.739, 0.706) on the dev set.
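
As a quick sanity check, F1 is the harmonic mean of precision and recall: 2 × 0.6760 × 0.7394 / (0.6760 + 0.7394) ≈ 0.7063, which matches the third value in the tuple.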
