from estimate import Precision,Recall,F1_score
import sys
import torch
import time
from tqdm import tqdm
from torch.utils.data import TensorDataset,DataLoader
sys.path.append('..')
from model.BERT_BiLSTM_CRF import BERT_BiLSTM_CRF
from config import Config
from utils import read_corpus,load_vocab
# Pick the GPU when available so the script also runs on CPU-only machines
# (the original hard-coded 'cuda' and crashed without a GPU).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = Config()
# token -> id and label -> id lookup tables
vocab = load_vocab(config.vocab)
label_dic = load_vocab(config.label_file)
tagset_size = len(label_dic)
# Each dev example exposes input_id / input_mask / label_id, padded to max_length.
dev_data = read_corpus(config.dev_file, max_length=config.max_length, label_dic=label_dic, vocab=vocab)
dev_ids = torch.LongTensor([temp.input_id for temp in dev_data])
dev_masks = torch.LongTensor([temp.input_mask for temp in dev_data])
dev_tags = torch.LongTensor([temp.label_id for temp in dev_data])
dev_dataset = TensorDataset(dev_ids, dev_masks, dev_tags)
# NOTE(review): shuffle=True is harmless but unnecessary for evaluation —
# the aggregated metrics are order-independent; kept to preserve behavior.
dev_loader = DataLoader(dev_dataset, shuffle=True, batch_size=config.batch_size)
model = BERT_BiLSTM_CRF(tagset_size,
                        config.bert_embedding,
                        config.rnn_hidden,
                        config.rnn_layer,
                        config.dropout,
                        config.pretrain_model_name,
                        device
                        ).to(device)
def valid(model, dataloader):
    """Run one evaluation pass over `dataloader` and compute P/R/F1.

    Args:
        model: trained BERT_BiLSTM_CRF exposing .device, .loss and .predict.
        dataloader: yields (input_ids, masks, tags) LongTensor batches,
            each of shape [batch_size, max_len].

    Returns:
        (epoch_time, epoch_loss, (precision, recall, f1_score))
    """
    model.eval()
    device = model.device
    pre_output = []   # predicted tag paths, one list per batch
    true_output = []  # gold tag sequences truncated to their unpadded length
    epoch_start = time.time()
    running_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            inputs, masks, tags = batch
            # number of non-padding tokens in each sequence
            real_length = torch.sum(masks, dim=1)
            # keep only the real tokens of each gold label sequence
            true_output.append(
                [line[:length] for line, length in zip(tags.tolist(), real_length.tolist())]
            )
            inputs = inputs.to(device)
            masks = masks.byte().to(device)  # CRF layer consumes a byte mask
            tags = tags.to(device)
            # feats: [batch_size, max_len, num_labels]
            feats = model(inputs, masks)
            loss = model.loss(feats, tags, masks)
            # one decoded path per sequence in the batch
            out_path = model.predict(feats, masks)
            pre_output.append(out_path)
            running_loss += loss.item()
    epoch_time = time.time() - epoch_start
    # max(..., 1) guards against ZeroDivisionError on an empty dataloader
    epoch_loss = running_loss / max(len(dataloader), 1)
    # pre_output / true_output are lists of batches of tag paths
    precision = Precision(pre_output, true_output)
    recall = Recall(pre_output, true_output)
    f1_score = F1_score(precision, recall)
    estimator = (precision, recall, f1_score)
    return epoch_time, epoch_loss, estimator
best_path = '../result/checkpoints/RoBERTa_result/RoBERTa_best.pth.tar'
# map_location lets a GPU-trained checkpoint load on whatever device
# this run uses (the original call failed on CPU-only machines).
checkpoint = torch.load(best_path, map_location=device)
model.load_state_dict(checkpoint['model'])
# Report the results instead of silently discarding the return value.
epoch_time, epoch_loss, (precision, recall, f1) = valid(model, dev_loader)
print(f'time: {epoch_time:.2f}s, loss: {epoch_loss:.4f}, '
      f'precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}')
# Output:
# (8.039861917495728,
#  5.8888872762521105,
#  (0.6759776536312849, 0.7393939393939394, 0.7062651005908666))