Xiaohei's NER exploration: data processing and Dataset construction, model definition and loading

import torch
import os
import time
import torch.nn as nn
import sys
sys.path.append('..')
from tqdm import tqdm
from model.BERT_BiLSTM_CRF import BERT_BiLSTM_CRF
from main import valid
class InputFeature(object):
    """Container for one padded training example: token ids, label ids, attention mask, plus the raw characters and tags for reference."""
    def __init__(self, input_id, label_id, input_mask, char, char_label):
        self.input_id = input_id
        self.label_id = label_id
        self.input_mask = input_mask
        self.char = char
        self.char_label = char_label
    def __repr__(self):
        return str({
            'input_id':self.input_id,
            'label_id':self.label_id,
            'input_mask':self.input_mask,
            'char':self.char,
            'char_label':self.char_label,
        })
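For orientation, here is a toy feature built by hand (the ids are illustrative, not real bert-base-chinese ids):

feat = InputFeature(input_id=[101, 2769, 102],
                    label_id=[1, 2, 3],
                    input_mask=[1, 1, 1],
                    char=['[CLS]', '我', '[SEP]'],
                    char_label=['<START>', 'O', '<EOS>'])
print(feat)  # __repr__ prints all five fields as a dict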
def load_vocab(vocab_file):
    """Map each token to its 0-based line number in the vocab file."""
    vocab = {}
    index = 0
    with open(vocab_file,'r',encoding = 'utf-8') as fp:
        while True:
            token = fp.readline()
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab
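load_vocab maps each token to its 0-based line number, which matches how BERT vocab files are laid out. A quick sanity check with the stock bert-base-chinese vocab.txt should print the familiar special-token ids:

vocab = load_vocab('../bert-base-chinese/vocab.txt')
print(vocab['[PAD]'], vocab['[UNK]'], vocab['[CLS]'], vocab['[SEP]'])  # 0 100 101 102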

def read_corpus(path, max_length, label_dic, vocab):
    """Parse a char/tag corpus (one 'char tag' pair per line, blank line between
    sentences) into a list of padded InputFeatures."""
    max_len = 0
    with open(path,'r',encoding = 'utf-8') as fp:
        result = []
        words = []
        labels = []
        for line in fp:
            contends = line.strip()
            tokens = contends.split(' ')
            if len(tokens) == 2:
                words.append(tokens[0])
                labels.append(tokens[1])
            else:
                if len(contends) == 0 and len(words) > 0:
                    if len(words) > max_len:
                        max_len = len(words)
                    if len(words) > max_length - 2:
                        words = words[:(max_length-2)]
                        labels = labels[:(max_length-2)]
                    words = ['[CLS]'] + words + ['[SEP]']
                    labels = ['<START>'] + labels + ['<EOS>']
                    input_ids = [vocab.get(i, vocab['[UNK]']) for i in words]
                    label_ids = [label_dic[i] for i in labels]
                    input_mask = [1] * len(input_ids)

                    # keep the raw characters too, mapping OOV chars to the '[UNK]' token
                    words = [i if i in vocab else '[UNK]' for i in words]
                    # pad to max_length: ids / labels / mask with 0, the raw char and tag lists with a '[PAD]' marker
                    if len(input_ids) < max_length:
                        input_ids.extend([0] * (max_length - len(input_ids)))
                        label_ids.extend([0] * (max_length - len(label_ids)))
                        input_mask.extend([0] * (max_length - len(input_mask)))

                        words.extend(['[PAD]'] * (max_length - len(words)))
                        labels.extend(['[PAD]'] * (max_length - len(labels)))
                    assert len(input_ids) == max_length
                    assert len(label_ids) == max_length
                    assert len(input_mask) == max_length
                    assert len(words) == max_length
                    assert len(labels) == max_length
                    feature = InputFeature(input_id=input_ids, label_id=label_ids, input_mask=input_mask, char=words, char_label=labels)
                    result.append(feature)
                    words = []
                    labels = []
        print(f'longest sentence in {path}: {max_len} chars')
        return result
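read_corpus expects the classic character-per-line format: each line holds one character and its tag separated by a single space, and a blank line closes a sentence. A made-up snippet in that shape (the real tag set is whatever tag.txt defines):

北 B-LOC
京 I-LOC
很 O
大 O

上 B-LOC
海 I-LOC

Each sentence is then wrapped with [CLS]/[SEP] and <START>/<EOS>, truncated to max_length - 2 if needed, and zero-padded up to max_length.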
# quick smoke test of the helpers above, left commented out:
'''
train_path = '../dataset/poety_data/example.train'
vocab_path = '../bert-base-chinese/vocab.txt'
tag_file = '../dataset/poety_data/tag.txt'
label_dic = load_vocab(tag_file)
vocab = load_vocab(vocab_path)
result = read_corpus(train_path,286,label_dic,vocab)
'''
from config import Config
from torch.utils.data import TensorDataset, DataLoader

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
config = Config()
vocab = load_vocab(config.vocab)
label_dic = load_vocab(config.label_file)
tagset_size = len(label_dic)
train_data = read_corpus(config.train_file, max_length=config.max_length, label_dic=label_dic, vocab=vocab)
dev_data = read_corpus(config.dev_file, max_length=config.max_length, label_dic=label_dic, vocab=vocab)

train_ids = torch.LongTensor([temp.input_id for temp in train_data])
train_masks = torch.LongTensor([temp.input_mask for temp in train_data])
train_tags = torch.LongTensor([temp.label_id for temp in train_data])
train_dataset = TensorDataset(train_ids, train_masks, train_tags)
#print(list(train_dataset)[0])
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size)
dev_ids = torch.LongTensor([temp.input_id for temp in dev_data])
dev_masks = torch.LongTensor([temp.input_mask for temp in dev_data])
dev_tags = torch.LongTensor([temp.label_id for temp in dev_data])
dev_dataset = TensorDataset(dev_ids, dev_masks, dev_tags)
dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=config.batch_size)  # evaluation order doesn't matter, so no shuffle
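With both loaders built, a one-batch shape check confirms everything lines up; each tensor should be (batch_size, max_length):

ids, masks, tags = next(iter(train_loader))
print(ids.shape, masks.shape, tags.shape)  # each torch.Size([config.batch_size, config.max_length])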
model = BERT_BiLSTM_CRF(tagset_size,
                        config.bert_embedding,
                        config.rnn_hidden,
                        config.rnn_layer,
                        config.dropout,
                        config.pretrain_model_name,
                        device
                       ).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=1)
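Since ReduceLROnPlateau was created with mode='max', it must be stepped with a score that should increase, e.g. the metric valid returns (start_estimator below). A sketch of the per-epoch hookup; train_epoch and config.epochs are assumed names, not part of this repo:

for epoch in range(config.epochs):                  # config.epochs: assumed attribute
    # train_epoch(model, train_loader, optimizer)   # hypothetical training step
    _, valid_loss, start_estimator = valid(model, dev_loader)
    scheduler.step(start_estimator)                 # halves the lr after `patience` epochs without improvement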
best_path = '../result/checkpoints/RoBERTa_result/RoBERTa_best.pth.tar'
checkpoint = torch.load(best_path, map_location=device)  # map_location lets a GPU-trained checkpoint load on a CPU-only machine
model.load_state_dict(checkpoint['model'])
_, valid_loss, start_estimator = valid(model, dev_loader)
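For the load above to work, the checkpoint must be a dict with the weights stored under the 'model' key; the matching save call at training time would look roughly like this (the 'optimizer' key is an assumption):

torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict()},
           best_path)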