1123LSTM

 

import torch
import torch.nn as nn
import os
from tqdm import tqdm
from torch.utils.data import Dataset,DataLoader

def get_data(path,sort_by_len=False,num=None):
    # Read a tab-separated file with one "text\tlabel" pair per line.
    all_text = []
    all_label = []
    with open(path,"r",encoding="utf8") as f:
        all_data = f.read().split("\n")
        if sort_by_len:
            # Sorting by length groups similarly sized sentences, which reduces padding per batch.
            all_data = sorted(all_data,key=lambda x:len(x))
    for data in all_data:
        try:
            if len(data) == 0:
                continue
            data_s = data.split("\t")
            if len(data_s) != 2:
                continue
            text,label = data_s
            label = int(label)
        except Exception as e:
            print(e)
        else:
            all_text.append(text)
            all_label.append(label)
    if num is None:
        return all_text,all_label
    else:
        return all_text[:num], all_label[:num]
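
# Illustrative sketch (not part of the original post): get_data expects one
# "text<TAB>label" pair per line. The file written below is made up for the demo;
# the real data lives under ../data/文本分类/.
def _demo_get_data():
    import tempfile
    with tempfile.NamedTemporaryFile("w",suffix=".txt",delete=False,encoding="utf8") as f:
        f.write("今天天气真好\t1\n这部电影很差\t0\n还行\t1\n")
        tmp_path = f.name
    texts,labels = get_data(tmp_path,sort_by_len=True)
    print(texts)   # ['还行', '今天天气真好', '这部电影很差']  (sorted by raw line length)
    print(labels)  # [1, 1, 0]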

def build_word2index(train_text):
    word_2_index =  {"PAD":0,"UNK":1}
    for text in train_text:
        for word in text:
            if word not in word_2_index:
                word_2_index[word] = len(word_2_index)
    return word_2_index
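
# Illustrative sketch (not part of the original post): the vocabulary always reserves
# index 0 for "PAD" and index 1 for "UNK"; every new character gets the next free index.
def _demo_build_word2index():
    w2i = build_word2index(["今天好","好天气"])
    print(w2i)  # {'PAD': 0, 'UNK': 1, '今': 2, '天': 3, '好': 4, '气': 5}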


class TextDataset(Dataset):
    def __init__(self,all_text,all_label):
        self.all_text = all_text
        self.all_label = all_label

    def __getitem__(self, index):
        # word_2_index is built at module level in __main__
        text = self.all_text[index]
        # Unseen characters fall back to the "UNK" index (1)
        text_index = [word_2_index.get(i,1) for i in text]
        label = self.all_label[index]
        text_len = len(text)
        return text_index,label,text_len

    def process_batch(self, data):
        # Collate function: pad every sample with "PAD" (0) up to the longest text in the batch.
        batch_text = []
        batch_label = []
        batch_len = []

        for d in data:
            batch_text.append(d[0])
            batch_label.append(d[1])
            batch_len.append(d[2])
        batch_max_len = max(batch_len)

        batch_text = [i + [0]*(batch_max_len-len(i)) for i in batch_text]
        return torch.tensor(batch_text),torch.tensor(batch_label)

    def __len__(self):
        return len(self.all_text)

class RNN_Model(nn.Module):
    # A hand-rolled recurrent layer, kept for comparison with the LSTM below.
    # Note: a classic Elman RNN computes h_t = tanh(W x_t + U h_{t-1} + b); this variant
    # instead mixes the current projection with the previous state using fixed weights
    # (0.2 / 0.8) and uses ReLU as the activation.
    def __init__(self,embedding_num,hidden_num):
        super().__init__()
        self.hidden_num = hidden_num
        self.W = nn.Linear(embedding_num,hidden_num)
        self.U = nn.Linear(hidden_num,hidden_num)
        self.act = nn.ReLU()

    def forward(self,x):  # x : batch * seq_len * embedding_num
        O = torch.zeros(x.shape[0],x.shape[1],self.hidden_num,device=x.device)
        t = torch.zeros(size=(x.shape[0], self.hidden_num),device=x.device)

        for i in range(x.shape[1]):
            w_emb = x[:,i]          # the i-th time step: batch * embedding_num

            h = self.W(w_emb)
            h_ = h*0.2 + t*0.8      # fixed-weight mix with the previous hidden state
            h__ = self.act(h_)

            t = self.U(h__)
            O[:,i] = t

        return O,t  # all hidden states, and the last hidden state
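
# Quick shape check (sketch, not in the original post): with a batch of 4 sentences,
# 7 time steps and 128-dimensional embeddings, the layer returns all hidden states
# (4, 7, 100) plus the last hidden state (4, 100).
def _demo_rnn_shapes():
    rnn = RNN_Model(embedding_num=128,hidden_num=100)
    fake_emb = torch.randn(4,7,128)
    all_states,last_state = rnn(fake_emb)
    print(all_states.shape,last_state.shape)  # torch.Size([4, 7, 100]) torch.Size([4, 100])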


class LSTM_Model(nn.Module):
    def __init__(self,embedding_num,hidden_num):
        super().__init__()
        self.hidden_num = hidden_num  # size of the LSTM hidden state
        # Each gate sees the concatenation of the current input and the previous hidden state.
        self.F = nn.Linear(embedding_num + hidden_num, hidden_num)  # forget gate
        self.I = nn.Linear(embedding_num + hidden_num, hidden_num)  # input gate
        self.C = nn.Linear(embedding_num + hidden_num, hidden_num)  # candidate cell state
        self.O = nn.Linear(embedding_num + hidden_num, hidden_num)  # output gate

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self,batch_x,a_pre=None,c_pre=None): # batch_x : batch * seq_len * embedding_num
        # a_pre and c_pre carry the history: the previous hidden state and cell state.
        if a_pre is None:
            a_pre = torch.zeros((batch_x.shape[0],self.hidden_num),device=batch_x.device)
        if c_pre is None:
            c_pre = torch.zeros((batch_x.shape[0],self.hidden_num),device=batch_x.device)

        letter_features = torch.zeros((*batch_x.shape[:2],self.hidden_num),device=batch_x.device)

        for word_idx in range(batch_x.shape[1]):
            x = batch_x[:,word_idx]
            x_a = torch.cat((x,a_pre),dim=1)
            # e.g. x is (10, 128) and a_pre is (10, 100); after cat, x_a is (10, 228)

            f_ = self.F(x_a)
            i_ = self.I(x_a)
            c_ = self.C(x_a)
            o_ = self.O(x_a)

            ft = self.sigmoid(f_)  # forget gate
            it = self.sigmoid(i_)  # input gate
            cct = self.tanh(c_)    # candidate cell state
            ot = self.sigmoid(o_)  # output gate

            c_next = ft * c_pre + it * cct  # new cell state
            th = self.tanh(c_next)
            a_next = ot * th                # new hidden state

            a_pre = a_next
            c_pre = c_next

            letter_features[:,word_idx] = a_next
            # a_next.shape == (batch_size, hidden_num); the last one is used for classification.
            # letter_features.shape == (batch_size, seq_len, hidden_num)

        return letter_features,(a_pre,c_pre)
    # letter_features holds the hidden state of every time step; a_pre and c_pre are the
    # final hidden and cell states. Unlike the RNN, the LSTM carries two pieces of history.
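
# Sanity check (sketch, not in the original post): the hand-rolled LSTM_Model produces
# the same output shapes as torch.nn.LSTM with batch_first=True; the values differ
# because nn.LSTM uses its own parameterization and initialization.
def _demo_lstm_shapes():
    hand_made = LSTM_Model(embedding_num=128,hidden_num=100)
    built_in = nn.LSTM(input_size=128,hidden_size=100,batch_first=True)
    fake_emb = torch.randn(4,7,128)

    out1,(a_last,c_last) = hand_made(fake_emb)
    out2,(h_n,c_n) = built_in(fake_emb)
    print(out1.shape,a_last.shape)           # torch.Size([4, 7, 100]) torch.Size([4, 100])
    print(out2.shape,h_n.squeeze(0).shape)   # torch.Size([4, 7, 100]) torch.Size([4, 100])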




class Model(nn.Module):
    def __init__(self,corpus_len,embedding_num,hidden_num,class_num):
        super().__init__()
        self.embedding = nn.Embedding(corpus_len,embedding_num)
        self.LSTM = LSTM_Model(embedding_num,hidden_num)
        self.classifier = nn.Linear(hidden_num,class_num)
        self.loss_fun = nn.CrossEntropyLoss()

    def forward(self,x,label=None):  # x : batch * sent_len
        x_emb = self.embedding(x)  # x_emb : batch * sent_len * embedding_num
        # output : batch * sent_len * hidden_num ; a_last, c_last : batch * hidden_num
        output,(a_last,c_last) = self.LSTM(x_emb)

        # Classify from the final hidden state (a_last == output[:, -1, :])
        pre = self.classifier(a_last)

        if label is not None:
            loss = self.loss_fun(pre,label)
            return loss
        else:
            return torch.argmax(pre,dim=-1)
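
# Illustrative sketch (not part of the original post): with a toy vocabulary of 20 indices
# and 3 classes, the model maps a padded batch of index tensors to one predicted class id
# per sentence, or to a scalar cross-entropy loss when labels are passed.
def _demo_model_forward():
    toy_model = Model(corpus_len=20,embedding_num=128,hidden_num=100,class_num=3)
    fake_batch = torch.tensor([[2,3,4,0],[5,6,0,0]])    # two sentences, padded with 0
    print(toy_model(fake_batch))                         # e.g. tensor([1, 2]) -- untrained, so arbitrary
    print(toy_model(fake_batch,torch.tensor([1,0])))     # scalar cross-entropy loss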



if __name__=="__main__":
    train_text, train_label = get_data(os.path.join("..", "data", "文本分类", "train.txt"), True,10000)
    dev_text, dev_label = get_data(os.path.join("..", "data", "文本分类", "dev.txt"), True,2000)

    # Build the vocabulary from both splits so __getitem__ never meets an unseen character here.
    word_2_index = build_word2index(train_text + dev_text)

    train_batch_size = 10
    embedding_num = 128
    hidden_num = 100
    epoch = 10
    lr = 0.001
    class_num = len(set(train_label))
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    train_dataset = TextDataset(train_text, train_label)
    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=False, collate_fn=train_dataset.process_batch)

    dev_dataset = TextDataset(dev_text, dev_label)
    dev_dataloader = DataLoader(dev_dataset, batch_size=10, shuffle=False, collate_fn=dev_dataset.process_batch)

    model = Model(len(word_2_index),embedding_num,hidden_num,class_num).to(device)
    opt = torch.optim.Adam(model.parameters(), lr)

    for e in range(epoch):
        print("*" * 100)

        model.train()
        for bi, (batch_text, batch_label) in tqdm(enumerate(train_dataloader, start=1)):
            batch_text = batch_text.to(device)
            batch_label = batch_label.to(device)

            loss = model(batch_text,batch_label)
            loss.backward()

            opt.step()
            opt.zero_grad()

            if bi % 50 == 0:
                print(f"loss:{loss:.2f}")

        model.eval()
        right_num = 0
        with torch.no_grad():  # no gradients are needed during evaluation
            for bi, (batch_text, batch_label) in tqdm(enumerate(dev_dataloader, start=1)):
                batch_text = batch_text.to(device)
                batch_label = batch_label.to(device)
                pre = model(batch_text)

                right_num += int(torch.sum(pre == batch_label))

        acc = right_num / len(dev_dataset)
        print(f"acc:{acc*100:.3f}%")