2023-03-21 Work Log

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from seqeval.metrics import f1_score, classification_report

def read_data(file):
    with open(file,"r",encoding="utf-8") as f:
        all_data = f.read().split("\n")

    sentences = []
    labels = []

    sentence = []
    label = []
    for data in all_data:
        data_s = data.split(" ")

        if len(data_s) != 2:
            if len(sentence)>0  and len(label) > 0:
                sentences.append(sentence)
                labels.append(label)

            sentence = []
            label = []
            continue

        sent, l = data_s
        sentence.append(sent)
        label.append(l)

    # Flush the last sentence if the file does not end with a blank line.
    if sentence and label:
        sentences.append(sentence)
        labels.append(label)
    return sentences, labels
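
# A sketch of the assumed input format (CoNLL-style, one "char tag" pair per
# line, blank lines between sentences); the sample below is hypothetical:
#
#     我 O
#     爱 O
#     北 B-LOC
#     京 I-LOC
#
# read_data then returns parallel lists, e.g.
#     sentences == [["我", "爱", "北", "京"], ...]
#     labels    == [["O", "O", "B-LOC", "I-LOC"], ...]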

def build_word(train_text):
    word_2_index = {"PAD":0,"UNK":1}

    for text in train_text:
        for w in text:
            word_2_index[w] = word_2_index.get(w,len(word_2_index))
    return word_2_index


def build_tag(train_tag):
    tag_2_index = {"PAD":0,"UNK":1,"O":2}

    for text in train_tag:
        for w in text:
            tag_2_index[w] = tag_2_index.get(w,len(tag_2_index))
    return tag_2_index
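
# Usage sketch: both builders assign each new symbol the next free index,
# reserving 0 for "PAD" and 1 for "UNK" (plus 2 for "O" in the tag map):
#     build_word([["我", "爱"]])   # {"PAD": 0, "UNK": 1, "我": 2, "爱": 3}
#     build_tag([["O", "B-LOC"]])  # {"PAD": 0, "UNK": 1, "O": 2, "B-LOC": 3}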

class NDataset(Dataset):
    def __init__(self,all_text,all_tag,word_2_index,tag_2_index,max_len,is_dev=False):
        self.all_text = all_text
        self.all_tag = all_tag
        self.word_2_index = word_2_index
        self.tag_2_index = tag_2_index
        self.max_len = max_len
        self.is_dev = is_dev

    def __getitem__(self,index):
        max_len = self.max_len
        text = self.all_text[index]
        if self.is_dev:
            max_len = len(text)

        text = text[:max_len]
        tag = self.all_tag[index][:max_len]
        assert len(text) == len(tag)
        text_len = len(tag)

        text_idx = [self.word_2_index.get(i,1) for i in text]
        tag_idx = [self.tag_2_index.get(i,1) for i in tag]

        text_idx = text_idx + [0] * (max_len-len(text_idx))
        tag_idx = tag_idx + [0] * (max_len-len(tag_idx))

        return torch.tensor(text_idx),torch.tensor(tag_idx),text_len

    def __len__(self):
        return len(self.all_tag)
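
# Shape sketch: with max_len=25 each training item is a pair of LongTensors of
# length 25 (ids padded with 0) plus the true length, so batch_size=10 yields
# (10, 25) batches; in is_dev mode max_len is the sentence's own length, so dev
# items stay unpadded (which is why the dev DataLoader uses batch_size=1).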

class NModel(nn.Module):
    def __init__(self,corpus_len,embedding_num,tag_num):
        super().__init__()
        self.embedding = nn.Embedding(corpus_len,embedding_num)
        self.rnn = nn.GRU(embedding_num,150,batch_first=True,bidirectional=True)

        self.classifier = nn.Linear(300,tag_num)
        self.loss_fun = nn.CrossEntropyLoss()

    def forward(self,x,batch_label=None):
        batch_size,seq_len = x.shape
        x = self.embedding(x)
        x,_ = self.rnn(x)
        pre = self.classifier(x)

        if batch_label is not None:
            loss = self.loss_fun(pre.reshape(batch_size*seq_len,-1),batch_label.reshape(-1))
            return loss
        else:
            return torch.argmax(pre,dim=-1)
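
# Smoke-test sketch with hypothetical sizes: given labels the model returns a
# scalar loss, otherwise per-token tag indices of shape (batch, seq_len):
#     m = NModel(corpus_len=100, embedding_num=150, tag_num=5)
#     x = torch.randint(0, 100, (2, 5))
#     y = torch.randint(0, 5, (2, 5))
#     m(x, y)  # 0-dim loss tensor
#     m(x)     # LongTensor of shape (2, 5)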

if __name__ == "__main__":  # metrics: f1-score, accuracy, precision, recall
    # Common NER tagging schemes: BIOES (B I E S O), BIOE (B I E O),
    # BIO (B I O); the data used here is BIO.
    train_text,train_label = read_data(os.path.join("..","data","ner","BIO","train.txt"))
    dev_text,dev_label = read_data(os.path.join("..","data","ner","BIO","dev.txt"))

    word_2_index = build_word(train_text)
    tag_2_index = build_tag(train_label)
    index_2_tag = list(tag_2_index)  # index -> tag (dicts preserve insertion order)

    max_len = 25
    batch_size = 10
    epoch = 20
    lr = 0.001
    embedding_num = 150
    tag_num = len(tag_2_index)
    corpus_len = len(word_2_index)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_dataset = NDataset(train_text,train_label,word_2_index,tag_2_index,max_len)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)  # shuffle training batches each epoch

    dev_dataset = NDataset(dev_text, dev_label, word_2_index, tag_2_index, max_len,is_dev=True)
    dev_dataloader = DataLoader(dev_dataset, 1, shuffle=False)

    model = NModel(corpus_len,embedding_num,tag_num).to(device)
    opt = torch.optim.AdamW(model.parameters(),lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt,5,gamma=0.8)
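    # StepLR multiplies the learning rate by gamma=0.8 every 5 epochs:
    # 0.001 -> 0.0008 (after epoch 5) -> 0.00064 (after 10) -> 0.000512 (after 15).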
    for e in range(epoch):
        model.train()
        for batch_text_idx, batch_tag_idx, batch_len in train_dataloader:
            batch_text_idx = batch_text_idx.to(device)
            batch_tag_idx = batch_tag_idx.to(device)

            loss = model(batch_text_idx, batch_tag_idx)

            loss.backward()
            opt.step()
            opt.zero_grad()


        lr_scheduler.step()
        print(f"epoch:{e},loss:{loss:.2f}, lr:{opt.param_groups[0]['lr']}")

        model.eval()
        all_predict = []

        with torch.no_grad():
            for batch_text_idx, batch_tag_idx, batch_len in dev_dataloader:
                batch_text_idx = batch_text_idx.to(device)

                pre = model(batch_text_idx).tolist()
                pre = [index_2_tag[i] for i in pre[0]]
                all_predict.append(pre)

        # Note: plain token-level accuracy over-rewards the dominant "O" tag,
        # so entity-level metrics from seqeval are used instead.
        f1 = f1_score(dev_label,all_predict,average='macro')
        print(f"epoch:{e},f1:{f1*100:.2f}%")
        print(classification_report(dev_label,all_predict))
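
    # A hedged inference sketch (the sample sentence is hypothetical): after
    # training, tag a raw string character by character with the trained model.
    model.eval()
    with torch.no_grad():
        sample = list("我爱北京")  # hypothetical input
        sample_idx = torch.tensor([[word_2_index.get(w, 1) for w in sample]]).to(device)
        sample_pre = model(sample_idx)[0].tolist()
        print(list(zip(sample, [index_2_tag[i] for i in sample_pre])))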
