Dataset、IterableDataset 读取大数据的思路

一、单进程读取数据

Dataset

在数据量很大,无法将全部数据加载到内存的情况下,可以在init中读出表数据行数,在__len__设置为长度返回,在__getitem__中根据idx读表,idx就可以表示读取的表的行数,一般在读表的时候写作 path/table_name?start={}&end={}

import torch
import numpy as np
from torch.utils.data import IterableDataset, Dataset

'''
需要先一次性把data都从文件或者表中读出来,知道数据的长度,为了生成index列表,长度为数据的长度
分batch训练的时候,dataloader根据分好的一个batch中的idx来读取这个batch中的数据
'''


a = [{'anchor_text': np.array([1, 1, 1]), 'anchor_vis': np.array([1, 1, 1])},
     {'anchor_text': np.array([2, 2, 1]), 'anchor_vis': np.array([4, 1, 1])},
     {'anchor_text': np.array([3, 3, 1]), 'anchor_vis': np.array([2, 1, 1])},
     {'anchor_text': np.array([4, 4, 1]), 'anchor_vis': np.array([3, 1, 1])}]


class TableDataset(Dataset):
    def __init__(self):
        self.tablepath = ''
        self.data_length = len(a)

    def __len__(self):
        return self.data_length

    def __getitem__(self, idx):
        return a[idx]


train_dataset = TableDataset()
train_ld = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)


for idx, batch_data in enumerate(train_ld):
    print(batch_data)
    at = batch_data['anchor_text'].to(torch.float32)
    # print("at--------",at)

IterableDataset

在数据量很大,无法将全部数据加载到内存的情况下,可以在__iter__中一行一行的读表,读一行就立马返回一行。

import torch
import numpy as np
from torch.utils.data import IterableDataset, Dataset
'''
不需要一次性知道数据长度
分batch训练的时候,dataloader根据一个batch的大小bs来执行__iter__函数bs次,得到这个batch的数据
'''

a = [{'anchor_text': np.array([1, 1, 1]), 'anchor_vis': np.array([1, 1, 1])},
     {'anchor_text': np.array([2, 2, 1]), 'anchor_vis': np.array([4, 1, 1])},
     {'anchor_text': np.array([3, 3, 1]), 'anchor_vis': np.array([2, 1, 1])},
     {'anchor_text': np.array([4, 4, 1]), 'anchor_vis': np.array([3, 1, 1])}]

class TableDataset2(IterableDataset):
    def __init__(self):
        self.tablepath = ''

    def __iter__(self):
        for line in a:
            print("line:",line)
            yield line


train_dataset = TableDataset2()
train_ld = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)


for idx, batch_data in enumerate(train_ld):
    print(batch_data)
    at = batch_data['anchor_text'].to(torch.float32)
    # print("at--------",at)

上述提到的处理数据量大的方法,都需要提前将数据处理好存入表中,程序读取数据就可以直接跑模型了。

二、多进程读取数据

后续有时间再补上

IterableDataset

当DataLoader设置为多进程时,每个进程都会拥有一个IterableDataset的生成器函数__iter__,每当这个进程收集到的数据达到batch size的时候,就把这批收集到的数据给loader,也就在for batch_data in train_loader: 的时候就能得到这批数据。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值