import torch
import os
import time
import torch.nn as nn
import sys
sys.path.append('..')
from tqdm import tqdm
from model.BERT_BiLSTM_CRF import BERT_BiLSTM_CRF
from main import valid
class InputFeature(object):
    """Container for one fully-encoded, padded training example.

    All sequence fields are padded to the same configured max length by
    the corpus reader before construction.
    """

    def __init__(self, input_id, label_id, input_mask, char, char_label):
        # BERT wordpiece ids (ints), zero-padded
        self.input_id = input_id
        # numeric tag ids aligned 1:1 with input_id
        self.label_id = label_id
        # 1 for real tokens, 0 for padding positions
        self.input_mask = input_mask
        # original characters incl. special tokens (padding positions hold 0)
        self.char = char
        # original string labels (padding positions hold 0)
        self.char_label = char_label

    def __repr__(self):
        # Attributes are assigned in declaration order, so vars() yields the
        # same key ordering the explicit dict literal would.
        return str(vars(self))
def load_vocab(vocab_file):
    """Load a vocabulary/tag file into a token -> index mapping.

    Each line is one token; the mapping value is the 0-based line number.
    Lines are stripped, so a blank interior line maps '' to its index and
    duplicate tokens keep the LAST index seen (same as the original loop).

    Args:
        vocab_file: path to a UTF-8 text file, one token per line.

    Returns:
        dict mapping each stripped line to its line index.
    """
    vocab = {}
    # Iterate the file directly instead of the original manual
    # `while True: readline()` loop — identical behavior, idiomatic form.
    with open(vocab_file, 'r', encoding='utf-8') as fp:
        for index, line in enumerate(fp):
            vocab[line.strip()] = index
    return vocab
def _build_feature(words, labels, max_length, label_dic, vocab):
    """Encode one sentence (parallel char/label lists) into an InputFeature."""
    # Reserve two positions for the [CLS]/[SEP] special tokens.
    if len(words) > max_length - 2:
        words = words[:max_length - 2]
        labels = labels[:max_length - 2]
    words = ['[CLS]'] + words + ['[SEP]']
    labels = ['<START>'] + labels + ['<EOS>']
    input_ids = [int(vocab[w]) if w in vocab else int(vocab['[UNK]']) for w in words]
    label_ids = [label_dic[l] for l in labels]
    input_mask = [1] * len(input_ids)
    # BUGFIX: keep characters as strings — out-of-vocab chars become the
    # '[UNK]' token. The original stored int(vocab['[UNK]']) here, mixing
    # integer ids into the character list.
    chars = [w if w in vocab else '[UNK]' for w in words]
    # Zero-pad every field to the fixed max_length.
    pad = max_length - len(input_ids)
    input_ids.extend([0] * pad)
    label_ids.extend([0] * pad)
    input_mask.extend([0] * pad)
    chars.extend([0] * pad)
    labels.extend([0] * pad)
    return InputFeature(input_id=input_ids, label_id=label_ids,
                        input_mask=input_mask, char=chars, char_label=labels)


def read_corpus(path, max_length, label_dic, vocab):
    """Read a CoNLL-style corpus and encode every sentence.

    The file holds one "char label" pair per line (space-separated), with a
    blank line between sentences. Lines that are neither a pair nor a blank
    separator are ignored, as before.

    Args:
        path: UTF-8 corpus file path.
        max_length: fixed output length (including [CLS]/[SEP]).
        label_dic: label -> id mapping (must contain '<START>' and '<EOS>').
        vocab: token -> id mapping (must contain '[CLS]', '[SEP]', '[UNK]').

    Returns:
        list of InputFeature, one per sentence. Prints the longest raw
        sentence length seen (pre-truncation), as the original did.
    """
    max_len = 0
    result = []
    words, labels = [], []
    with open(path, 'r', encoding='utf-8') as fp:
        for line in fp:
            contends = line.strip()
            tokens = contends.split(' ')
            if len(tokens) == 2:
                words.append(tokens[0])
                labels.append(tokens[1])
            elif len(contends) == 0 and len(words) > 0:
                max_len = max(max_len, len(words))
                result.append(_build_feature(words, labels, max_length, label_dic, vocab))
                words, labels = [], []
    # BUGFIX: flush the final sentence when the file does not end with a
    # blank line — the original silently dropped it.
    if words:
        max_len = max(max_len, len(words))
        result.append(_build_feature(words, labels, max_length, label_dic, vocab))
    print(max_len)
    return result
# Example standalone usage of the loaders above (kept for reference):
#   train_path = '../dataset/poety_data/example.train'
#   vocab_path = '../bert-base-chinese/vocab.txt'
#   tag_file = '../dataset/poety_data/tag.txt'
#   label_dic = load_vocab(tag_file)
#   vocab = load_vocab(vocab_path)
#   result = read_corpus(train_path, 286, label_dic, vocab)

from config import Config
from torch.utils.data import TensorDataset, DataLoader

# Prefer GPU when available; `device` is also passed to the model and used
# when restoring the checkpoint.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
config = Config()

# Token and label vocabularies drive both encoding and the CRF tagset size.
vocab = load_vocab(config.vocab)
label_dic = load_vocab(config.label_file)
tagset_size = len(label_dic)

# Encode train/dev corpora into fixed-length id tensors.
train_data = read_corpus(config.train_file, max_length=config.max_length, label_dic=label_dic, vocab=vocab)
dev_data = read_corpus(config.dev_file, max_length=config.max_length, label_dic=label_dic, vocab=vocab)

train_ids = torch.LongTensor([temp.input_id for temp in train_data])
train_masks = torch.LongTensor([temp.input_mask for temp in train_data])
train_tags = torch.LongTensor([temp.label_id for temp in train_data])
train_dataset = TensorDataset(train_ids, train_masks, train_tags)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size)

dev_ids = torch.LongTensor([temp.input_id for temp in dev_data])
dev_masks = torch.LongTensor([temp.input_mask for temp in dev_data])
dev_tags = torch.LongTensor([temp.label_id for temp in dev_data])
dev_dataset = TensorDataset(dev_ids, dev_masks, dev_tags)
dev_loader = DataLoader(dev_dataset, shuffle=True, batch_size=config.batch_size)

model = BERT_BiLSTM_CRF(tagset_size,
                        config.bert_embedding,
                        config.rnn_hidden,
                        config.rnn_layer,
                        config.dropout,
                        config.pretrain_model_name,
                        device
                        ).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
# Halve the LR after `patience` epochs without improvement of the tracked metric.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=1)

best_path = '../result/checkpoints/RoBERTa_result/RoBERTa_best.pth.tar'
# BUGFIX: map_location keeps the load working on CPU-only machines even when
# the checkpoint was saved from a GPU run (bare torch.load crashed there).
checkpoint = torch.load(best_path, map_location=device)
model.load_state_dict(checkpoint['model'])
_, valid_loss, start_estimator = valid(model, dev_loader)
# NER exploration: data processing, Dataset construction, and model definition/loading.