pytorch加载大数据
本文介绍的数据特点:
- 数据量大,无法一次读取到内存中
- 数据存储在 CSV 或文本文件中(每一行是一个 sample,包含 feature 和 label)
要求:
- 每次读取一小块数据到内存
- 能够batch
- 能够shuffle
自定义MyDataset,继承torch.utils.data.Dataset,重写__init__(),__len__(),__getitem__(),增加initial()
import torch.utils.data as Data
import random
class MyDataset(Data.Dataset):
    """Stream samples from a large text file that does not fit in memory.

    Only ``nraws`` lines are held in memory at a time; each chunk can be
    shuffled independently, which gives an approximate global shuffle.

    Call ``initial()`` before every epoch to (re)open the file and load the
    first chunk. ``__getitem__`` ignores its index argument and serves
    samples sequentially from the in-memory chunk, so this dataset must be
    used with a sequential sampler (the DataLoader default).
    """

    def __init__(self, file_path, nraws, shuffle=False):
        """
        file_path: path to the dataset file (one sample per line,
                   containing both feature and label)
        nraws: number of samples kept in memory at a time
        shuffle: whether to shuffle the samples inside each chunk
        """
        # Count the total number of samples once, up front.
        file_raws = 0
        with open(file_path, 'r') as f:
            for _ in f:
                file_raws += 1
        self.file_path = file_path
        self.file_raws = file_raws
        self.nraws = nraws
        self.shuffle = shuffle
        self.finput = None  # opened by initial()

    def _load_chunk(self):
        """Replace the in-memory chunk with the next nraws lines of the file."""
        # BUG FIX: the previous version appended new lines to self.samples
        # without clearing it, so already-served samples were re-served after
        # every refill and the sample count grew past the true remainder.
        self.samples = []
        for _ in range(self.nraws):
            line = self.finput.readline()  # one line == feature + label
            if not line:
                break  # end of file
            self.samples.append(line)
        if self.shuffle:
            random.shuffle(self.samples)
        self.current_sample_num = len(self.samples)
        self.index = list(range(self.current_sample_num))

    def initial(self):
        """(Re)open the file and load the first chunk; call before each epoch."""
        # BUG FIX: close the handle from the previous epoch instead of
        # leaking one open file descriptor per epoch.
        if self.finput is not None and not self.finput.closed:
            self.finput.close()
        self.finput = open(self.file_path, 'r')
        self._load_chunk()

    def __len__(self):
        return self.file_raws

    def __getitem__(self, item):
        # `item` is ignored: samples are served in chunk order, so this
        # dataset only works with a sequential sampler.
        idx = self.index.pop(0)
        # BUG FIX: keep the popped sample in its own variable. The old code
        # reused `data` as the readline target during the refill below, so at
        # every chunk boundary it returned a raw just-read line (or '' at
        # EOF) instead of the sample it had popped.
        sample = self.samples[idx]
        self.current_sample_num -= 1
        if self.current_sample_num <= 0:
            # Current chunk exhausted: load the next one from the file.
            self._load_chunk()
            if not self.samples:
                # End of file reached; release the handle.
                self.finput.close()
        return sample
if __name__ == "__main__":
    # Demo: stream a large text file chunk-by-chunk for several epochs.
    datapath = r"C:\Users\hn\Desktop\test.txt"
    batch_size = 64
    nraws = 1000  # samples held in memory at a time
    epoch = 3

    train_dataset = MyDataset(datapath, nraws, shuffle=True)
    for _ in range(epoch):
        # Reopen the file and reload the first chunk at the start of every epoch.
        train_dataset.initial()
        # NOTE: shuffling happens inside the dataset chunk; keep the
        # DataLoader's default sequential sampler (no shuffle=True here).
        train_iter = Data.DataLoader(dataset=train_dataset, batch_size=batch_size)
        for data in train_iter:  # idiom fix: enumerate index was unused
            print(data)