Pytorch的DataLoader和Dataset以及TensorDataset的源码分析和使用

1.为什么要用DataLoader和Dataset

要对大量数据进行加载和处理时因为可能会出现内存不够用的情况,这时候就需要用到数据集类Dataset或TensorDataset和数据集加载类DataLoader了。使用这些类后可以将原本的数据分成小块,在需要使用的时候再一部分一本分读进内存中,而不是一开始就将所有数据读进内存中。

2.Dateset的使用

pytorch中的torch.utils.data.Dataset是表示数据集的抽象类,但它一般不直接使用,而是通过自定义一个数据集来使用。来自定义数据集应该继承Dataset并应该有实现返回数据集尺寸的__len__方法和用来获取索引数据的__getitem__方法。Dataset类的源码如下:

class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

可以看到Dataset类中没有__len__方法,虽然有__getitem__方法,但是并没有实现啥有用的功能。所以要写一个Dataset类的子类来实现其应有的功能。

自定义类的实现举例:

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.autograd import Variable
import numpy as np
import pandas as pd

value_df = pd.read_csv('data1.csv')
value_array = np.array(value_df)
print("value_array.shape =", value_array.shape)  # (73700, 300)
value_size = value_array.shape[0]  # 73700
train_size = int(0.7*value_size)

train_array = val_array[:train_size]  
train_label_array = val_array[60:train_size+60]

class DealDataset(Dataset):
    """
        下载数据、初始化数据,都可以在这里完成
    """

    def __init__(self, *arrays):
        assert all(arrays[0].shape[0] == array.shape[0] for array in arrays)
        self.arrays = arrays

    def __getitem__(self, index):
        return tuple(array[index] for array in self.arrays)

    def __len__(self):
        return self.arrays[0].shape[0]


# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
train_dataset = DealDataset(train_array, train_label_array)

train_loader2 = DataLoader(dataset=train_dataset,
                           batch_size=32,
                           shuffle=True)

for epoch in range(2):
    for i, data in enumerate(train_loader2):
        # 将数据从 train_loader 中读出来,一次读取的样本数是32个
        inputs, labels = data

        # 将这些数据转换成Variable类型
        inputs, labels = Variable(inputs), Variable(labels)

        # 接下来就是跑模型的环节了,我们这里使用print来代替
        print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())


结果:

epoch: 0 的第 0 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 1 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 2 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 3 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 4 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 5 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
...

3.TensorDataset的使用

TensorDataset是可以直接使用的数据集类,它的源码如下:

class TensorDataset(Dataset):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

可以看到TensorDataset类是Dataset类的子类,且拥有返回数据集尺寸的__len__方法和用来获取索引数据的__getitem__方法,所以可以直接使用。它的结构跟上面自定义的子类的结构是一样的,惟一的不同是TensorDataset已经规定了传入的数据必须是torch.Tensor类型的,而自定义子类可以自由设定。

使用举例:

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.autograd import Variable
import numpy as np
import pandas as pd

value_df = pd.read_csv('data1.csv')
value_array = np.array(value_df)
print("value_array.shape =", value_array.shape)  # (73700, 300)
value_size = value_array.shape[0]  # 73700
train_size = int(0.7*value_size)

train_array = val_array[:train_size]  
train_tensor = torch.tensor(train_array, dtype=torch.float32).to(device)
train_label_array = val_array[60:train_size+60]
train_labels_tensor = torch.tensor(train_label_array,dtype=torch.float32).to(device)

train_dataset = TensorDataset(train_tensor, train_labels_tensor)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=100,
                          shuffle=False,
                          num_workers=0)

for epoch in range(2):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        print(epoch, i, "inputs", inputs.data.size(), "labels", labels.data.size())

结果:

0 0 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 1 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 2 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 3 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 4 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 5 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 6 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 7 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 8 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 9 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 10 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
...
  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

comli_cn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值