1116Rnn复现

import torch
import torch.nn as nn
import os
from tqdm import tqdm
from torch.utils.data import Dataset,DataLoader

def get_data(path,sort_by_len=0,num=None):
    #w我们这里引入一个参数,就是为了去给数据集进行一个排序,sort_by_len=0表示不排序,sort_by_len=1表示按照文本长度进行排序,num表示要取的样本数,这样训练集在取的时候也不用shuffle了。
    #这样就可以避免长短文本再切割填充的时候有太多的垃圾数据,提高训练的效率。
    all_text = []
    all_label = []
    with open(path,"r",encoding="utf8") as f:
        all_data = f.read().split("\n")
        if sort_by_len == True:
            all_data = sorted(all_data,key=lambda x:len(x))
    for data in all_data:
        try:
            if len(data) == 0:
                continue
            data_s = data.split("	")
            if len(data_s) != 2:
                continue
            text,label = data_s
            label = int(label)

        except Exception as e:
            print(e)
        else:
            all_text.append(text)
            all_label.append(int(label))
    if num is None:
            return all_text,all_label
    else:
        return all_text[:num], all_label[:num]

def build_word2index(train_text):
    word_2_index =  {"PAD":0,"UNK":1}
    for text in train_text:
        for word in text:
            if word not in word_2_index:
                word_2_index[word] = len(word_2_index)
    return word_2_index


class TextDataset(Dataset):
    def __init__(self,all_text,all_lable):
        self.all_text = all_text
        self.all_lable = all_lable

    def __getitem__(self, index):
        global word_2_index
        text = self.all_text[index]
        text_index = [word_2_index[i] for i in text]
        label = self.all_lable[index]
        text_len = len(text)
        return text_index,label,text_len


    def process_batch_batch(self, data):
        global word_2_index
        batch_text = []
        batch_label = []
        batch_len = []

        for d in data:
            batch_text.append(d[0]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值