Loading large datasets in PyTorch

This post deals with data that has the following characteristics:

  1. The dataset is too large to be read into memory all at once
  2. The data is stored in a CSV or plain-text file, with one sample per line (features plus the label); see the example line after this list
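
For concreteness, a single line might look like the following (a hypothetical comma-separated layout with the label in the last column; the actual column layout depends on your file):

0.27,1.35,4.20,1    # three feature values followed by the label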

Requirements:

  1. Read only a small chunk of the data into memory at a time
  2. Support batching
  3. Support shuffling

 

Define a custom MyDataset that subclasses torch.utils.data.Dataset, overrides __init__(), __len__() and __getitem__(), and adds an initial() method.

 

import torch.utils.data as Data
import random
 
class MyDataset(Data.Dataset):
    def __init__(self, file_path, nraws, shuffle=False):
        """
        file_path: path to the dataset file
        nraws: number of samples kept in memory at a time (the shuffle buffer size)
        shuffle: whether to shuffle the samples within each buffer
        """
        # count the total number of samples (one sample per line)
        file_raws = 0
        with open(file_path, 'r') as f:
            for _ in f:
                file_raws += 1
        self.file_path = file_path
        self.file_raws = file_raws
        self.nraws = nraws
        self.shuffle = shuffle
 
    def initial(self):
        # (re)open the file and fill the in-memory buffer; call this once per epoch
        self.finput = open(self.file_path, 'r')
        self.samples = list()

        # read up to nraws samples into memory
        for _ in range(self.nraws):
            line = self.finput.readline()   # one line = one sample (features + label)
            if line:
                self.samples.append(line)
            else:
                break
        self.current_sample_num = len(self.samples)
        self.index = list(range(self.current_sample_num))
        if self.shuffle:
            random.shuffle(self.samples)
 
    def __len__(self):
        return self.file_raws
 
    def __getitem__(self, item):
        # serve samples from the in-memory buffer in order (the buffer itself may be
        # shuffled); the `item` argument is ignored, so this only works with the
        # default sequential sampler and num_workers=0
        idx = self.index[0]
        data = self.samples[idx]
        self.index = self.index[1:]
        self.current_sample_num -= 1

        if self.current_sample_num <= 0:
            # the buffer is exhausted: discard it and read the next nraws samples
            self.samples = list()
            for _ in range(self.nraws):
                line = self.finput.readline()   # one line = one sample (features + label)
                if line:
                    self.samples.append(line)
                else:
                    break
            self.current_sample_num = len(self.samples)
            self.index = list(range(self.current_sample_num))
            if self.shuffle:
                random.shuffle(self.samples)

        return data
 
if __name__ == "__main__":
    datapath = r"C:\Users\hn\Desktop\test.txt"
    batch_size = 64
    nraws = 1000
    epoch = 3
    train_dataset = MyDataset(datapath, nraws, shuffle=True)
    for _ in range(epoch):
        train_dataset.initial()   # refill the buffer from the start of the file each epoch
        train_iter = Data.DataLoader(dataset=train_dataset, batch_size=batch_size)
        for _, data in enumerate(train_iter):
            print(data)
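
The dataset above returns each sample as a raw text line, so the features and the label still have to be parsed before training. Below is a minimal sketch of one way to do this, assuming the hypothetical comma-separated layout shown earlier (feature columns followed by the label in the last column); parse_batch is not part of the original code and would need to be adapted to your actual file format.

import torch

def parse_batch(lines):
    # lines: a list of raw text lines, which is what the default collate_fn
    # produces for a batch of string samples
    features, labels = [], []
    for line in lines:
        values = line.strip().split(',')          # assumption: comma-separated columns
        features.append([float(v) for v in values[:-1]])
        labels.append(int(values[-1]))            # assumption: label in the last column
    return torch.tensor(features), torch.tensor(labels)

# usage inside the training loop:
#   for data in train_iter:
#       x, y = parse_batch(data)

Alternatively, this parsing could live inside __getitem__, in which case the DataLoader's default collation will stack the per-sample tensors into batch tensors for you.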

 
