从零学PyTorch:DataLoader构建高效的自定义数据集

Torch中可以创建一个DataSet对象,并与dataloader一起使用,在训练模型时不断为模型提供数据Torch中DataLoader的参数如下

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

其中最重要的参数是dataset,是一个抽象类,包含两种类型:map-style datasets 和 iterable-style datasets.

1、map-style datasets 映射样式

借助映射样式构建数据集有两种方法,一种是构建dateset类,另外一种是借助TensorDataset直接将数据包装成dataset类,再传入到dataloader.

第一种方法:构建dateset类

该类型的dataset,其所有的子类必须重写__getitem__() 方法和__len()__方法:

(1)其中__getitem__函数的作用是根据索引index遍历数据 (2)__len__函数的作用是返回数据集的长度 (3)在创建的dataset类中可根据自己的需求对数据进行处理。可编写独立的数据处理函数,在__getitem__函数中进行调用;或者直接将数据处理方法写在__getitem__函数中或者__init__函数中,但__getitem__必须根据index返回响应的值,该值会通过index传到dataloader中进行后续的batch批处理。

def __getitem__(self, index):
    return self.src[index], self.trg[index]
def __len__(self):
    return len(self.src)

以时间序列使用示例,输入3个时间步,输出1个时间步,batchsize为5:

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

class GetTrainTestData(Dataset):
    def __init__(self, input_len, output_len, train_rate, is_train=True):
        super().__init__()
        # 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。
        # 以下数据组织这块既可以放在init方法里,也可以放在getitem方法里
        self.x = torch.sin(torch.arange(1, 1000, 0.1))
        self.sample_num = len(self.x)
        self.input_len = input_len
        self.output_len = output_len
        self.train_rate = train_rate
        self.src,  self.trg = [], []
        if is_train:
            for i in range(int(self.sample_num*train_rate)-self.input_len-self.output_len):
                self.src.append(self.x[i:(i+input_len)])
                self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])
        else:
            for i in range(int(self.sample_num*train_rate), self.sample_num-self.input_len-self.output_len):
                self.src.append(self.x[i:(i+input_len)])
                self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])
        print(len(self.src), len(self.trg))

    def __getitem__(self, index):
        return self.src[index], self.trg[index]

    def __len__(self):
        return len(self.src)  # 或者return len(self.trg), src和trg长度一样


data_train = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=True)
data_test = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=False)
data_loader_train = DataLoader(data_train, batch_size=5, shuffle=False)
data_loader_test = DataLoader(data_test, batch_size=5, shuffle=False)
# i_batch的多少根据batch size和def __len__(self)返回的长度确定
# batch_data返回的值根据def __getitem__(self, index)来确定
# 对训练集:(不太清楚enumerate返回什么的时候就多print试试)
for i_batch, batch_data in enumerate(data_loader_train):
    print(i_batch)  # 打印batch编号
    print(batch_data[0].size())  # 打印该batch里面src
    print(batch_data[1].size())  # 打印该batch里面trg
# 对测试集:(下面的语句也可以)
for i_batch, (src, trg) in enumerate(data_loader_test):
    print(i_batch)  # 打印batch编号
    print(src.size())  # 打印该batch里面src的尺寸
    print(trg.size())  # 打印该batch里面trg的尺寸    

output:

7988 7988
1994 1994
0
torch.Size([5, 3])
torch.Size([5, 1])
1
torch.Size([5, 3])
torch.Size([5, 1])
...

第二种方法:借助TensorDataset直接将数据包装成dataset类

另一种方法是直接使用 TensorDataset 来将数据包装成Dataset类,再使用dataloader

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset

src = torch.sin(torch.arange(1, 1000, 0.1))
trg = torch.cos(torch.arange(1, 1000, 0.1))

data = TensorDataset(src, trg)
data_loader = DataLoader(data, batch_size=5, shuffle=False)
for i_batch, batch_data in enumerate(data_loader):
    print(i_batch)  # 打印batch编号
    print(batch_data[0].size())  # 打印该batch里面src
    print(batch_data[1].size())  # 打印该batch里面trg

output:

0
torch.Size([5])
torch.Size([5])
1
torch.Size([5])
torch.Size([5])
...

2、Iterable-style datasets可迭代样式

可迭代样式的数据集是IterableDataset的一个实例,该实例必须重写__iter__方法,该方法用于对数据集进行迭代。这种类型的数据集特别适合随机读取数据不太可能实现的情况,并且批处理大小batchsize取决于获取的数据。比如读取数据库,远程服务器或者实时日志等数据的时候,可使用该样式,一般时序数据不使用这种样式。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

当交通遇上机器学习

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值