import torch
import torch.nn as nn
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
def get_data(path, sort_by_len=False, num=None):
    # We introduce a parameter here to sort the dataset: sort_by_len=False keeps
    # the original order, sort_by_len=True sorts the samples by text length, and
    # num is the number of samples to take. With length sorting, the training set
    # no longer needs to be shuffled when it is read.
    # This way, when long and short texts are cut and padded into batches, far
    # less junk padding is produced, which improves training efficiency.
    all_text = []
    all_label = []
    with open(path, "r", encoding="utf8") as f:
        all_data = f.read().split("\n")
    if sort_by_len:
        all_data = sorted(all_data, key=len)
    for data in all_data:
        try:
            # Skip empty lines and lines that do not split into exactly
            # "text label"; a bad label also lands in the except branch.
            if len(data) == 0:
                continue
            data_s = data.split(" ")
            if len(data_s) != 2:
                continue
            text, label = data_s
            label = int(label)
        except Exception as e:
            print(e)
        else:
            all_text.append(text)
            all_label.append(label)
    if num is None:
        return all_text, all_label
    else:
        return all_text[:num], all_label[:num]
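# Usage sketch (the path "data/train.txt", the "text<space>label" line format,
# and num=2000 are illustrative assumptions, not from the original post):
# train_text, train_label = get_data("data/train.txt", sort_by_len=True, num=2000)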
def build_word2index(train_text):
    # Build a character-level vocabulary from the training texts.
    # Index 0 is reserved for padding, index 1 for unknown characters.
    word_2_index = {"PAD": 0, "UNK": 1}
    for text in train_text:
        for word in text:
            if word not in word_2_index:
                word_2_index[word] = len(word_2_index)
    return word_2_index
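# For example, building the vocabulary from the training split only keeps
# index 0 for PAD and index 1 for UNK, so characters in the validation or
# test sets that were never seen in training can fall back to UNK:
# word_2_index = build_word2index(train_text)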
class TextDataset(Dataset):
    def __init__(self, all_text, all_label):
        self.all_text = all_text
        self.all_label = all_label

    def __getitem__(self, index):
        global word_2_index
        text = self.all_text[index]
        # Map each character to its index, falling back to UNK for
        # characters that did not appear in the training set.
        text_index = [word_2_index.get(i, word_2_index["UNK"]) for i in text]
        label = self.all_label[index]
        text_len = len(text)
        return text_index, label, text_len
    def process_batch_batch(self, data):
        # Custom collate_fn: pad every text in the batch to the length of the
        # longest one, then stack texts, labels, and lengths into tensors.
        global word_2_index
        batch_text = []
        batch_label = []
        batch_len = []
        for d in data:
            batch_text.append(d[0])
            batch_label.append(d[1])
            batch_len.append(d[2])
        max_len = max(batch_len)
        batch_text = [t + [word_2_index["PAD"]] * (max_len - len(t)) for t in batch_text]
        return torch.tensor(batch_text), torch.tensor(batch_label), torch.tensor(batch_len)
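# Wiring sketch: the dataset's padding method is passed as the DataLoader's
# collate_fn. batch_size=32 is an assumed value; shuffle is left False because
# sort_by_len=True already fixed a length-sorted sample order:
# train_dataset = TextDataset(train_text, train_label)
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False,
#                           collate_fn=train_dataset.process_batch_batch)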