pytorch深度学习和入门实战(二)Dataset和DataLoader使用详解

1.数据处理工具箱概述

数据下载和预处理是机器学习、深度学习实际项目中耗时又重要的任务,
尤其是数据预处理,关系到数据质量和模型性能,往往要占据项目的大部分时间。

PyTorch涉及数据处理(数据装载、数据预处理、数据增强等)主要工具包及相互关系如图:
在这里插入图片描述
主要包括两大部分:
(1)torch.utils.data相关部分
torch.utils.data工具包,它包括以下4个类函数。
1)Dataset:是一个抽象类,其他数据集需要继承这个类,并且覆写其
中的两个方法( getitem_()、len ())。
2)DataLoader:定义一个新的迭代器,实现批量(batch)读取,打乱
数据(shuffle)并提供并行加速等功能。
3)random_split:把数据集随机拆分为给定长度的非重叠的新数据集。
4)*sampler:多种采样函数

(2)torchvision工具包相关部分
它包括4个类,各类的主要功能如下。
1)datasets:提供常用的数据集加载,设计上都是继承自
torch.utils.data.Dataset,主要包括MMIST、CIFAR10/100、ImageNet和COCO
等。
2)models:提供深度学习中各种经典的网络结构以及训练好的模型
(如果选择pretrained=True),包括AlexNet、VGG系列、ResNet系列、
Inception系列等。
3)transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image
对象的操作。
4)utils:含两个函数,一个是make_grid,它能将多张图片拼接在一个
网格中;另一个是save_img,它能将Tensor保存成图片。

2. torch.utils.data简介

utils.data包括DatasetDataLoader

2.1 torch.utils.data.Dataset为抽象类。

# 官方类的定义和相关说明
class Dataset(object):
    """An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError
        
    def __len__(self):
        raise NotImplementedError
        
    def __add__(self, other):
        return ConcatDataset([self, other])


    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

上述代码是pytorch中Datasets的源码,注意成员方法__getitem__和__len__都是未实现的。我们要实现自定义Datasets类来完成数据的读取,则只需要完成这两个成员方法的重写。
首先__getitem__方法用来从datasets中读取一条数据,这条数据包含训练图片(已CV距离)和标签,参数index表示图片和标签在总数据集中的Index。
其次__len__ 方法返回数据集的总长度(训练集的总数)。

Dataset只支持两种类型的数据集:map-style datasets, iterable-style datasets.

2.2 utils.data.DataLoader

__getitem__一次只能获取一个数据,
所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。
data.DataLoader(
	dataset,
	batch_size=1,
	shuffle=False,
	sampler=None,
	batch_sampler=None,
	num_workers=0,
	collate_fn=<function default_collate at 0x7f108ee01620>,
	pin_memory=False,
	drop_last=False,
	timeout=0,
	worker_init_fn=None,
)
函数说明

·dataset:加载的数据集。
·batch_size:批大小。
·shuffle:是否将数据打乱。
·sampler:样本抽样。
·num_workers:使用多进程加载的进程数,0代表不使用多进程。
·collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可。
·pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些。
·drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

举例说明使用过程:

1)导入需要的模块

import torch
from torch.utils import data
import numpy as np

2)定义获取数据集的类。
该类继承基类Dataset,自定义一个数据集及对应标签

class TestDataset(data.Dataset):#继承Dataset
	def __init__(self):
		#一些由2维向量表示的数据集
		self.Data=np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])
		#这是数据集对应的标签
		self.Label=np.asarray([0,1,0,1,2])
	def __getitem__(self, index):
		#把numpy转换为Tensor
		txt=torch.from_numpy(self.Data[index])
		label=torch.tensor(self.Label[index])
		return txt,label
	def __len__(self):
		return len(self.Data)

3)获取数据集中数据

Test=TestDataset()
print(Test[2]) #相当于调用__getitem__(2)
print(Test.__len__())
#输出:
#(tensor([2, 1]), tensor(0))
#5

4) 批处理
以上数据以tuple返回,每次只返回一个样本。实际上,Dateset只负责数据的抽取,调用一次__getitem__只返回一个样本。如果希望批量处理(batch),还要同时进行shuffle和并行加速等操作,可选择DataLoader。

test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=2)
for i,traindata in enumerate(test_loader):
	print('i:',i)
	Data,Label=traindata
	print('data:',Data)
	print('Label:',Label)
#输出
i: 0
data: tensor([[1, 2],[3, 4]])
Label: tensor([0, 1])
i: 1
data: tensor([[2, 1],[3, 4]])
Label: tensor([0, 1])
i: 2
data: tensor([[4, 5]])
Label: tensor([2])

从这个结果可以看出,这是批量读取。我们可以像使用迭代器一样使用它,比如对它进行循环操作。不过由于它不是迭代器,我们可以通过iter命令将其转换为迭代器

dataiter=iter(test_loader)
imgs,labels=next(dataiter)

2.3 下面介绍一下自定义数据集构成方法

类型1:map-style datasets

A:构建dateset类
重点是把 x 和 label 都分别装入两个列表 self.src 和 self.trg ,然后通过 getitem(self, index)返回对应元素。

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
 
class My_dataset(Dataset):
    def __init__(self):
        super().__init__()
        # 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。
        # 以下数据组织这块既可以放在init方法里,也可以放在getitem方法里
        self.x = torch.randn(1000,3)
        self.y = self.x.sum(axis=1)
        self.src,  self.trg = [], []
        for i in range(1000):
            self.src.append(self.x[i])
            self.trg.append(self.y[i])
           
    def __getitem__(self, index):
        return self.src[index], self.trg[index]

    def __len__(self):
     # 或者return len(self.trg), src和trg长度一样
        return len(self.src) 
        
 
data_train = My_dataset()
data_test = My_dataset()
data_loader_train = DataLoader(data_train, batch_size=5, shuffle=False)
data_loader_test = DataLoader(data_test, batch_size=5, shuffle=False)

# i_batch的多少根据batch size和def __len__(self)返回的长度确定
# batch_data返回的值根据def __getitem__(self, index)来确定
for i_batch, batch_data in enumerate(data_loader_train):
    print(i_batch)  # 打印batch编号
    print(batch_data[0])  # 打印该batch里面src
    print(batch_data[1])  # 打印该batch里面trg
# 对测试集:(下面的语句也可以)
for i_batch, (src, trg) in enumerate(data_loader_test):
    print(i_batch)  # 打印batch编号
    print(src)  # 打印该batch里面src的尺寸
    print(trg)  # 打印该batch里面trg的尺寸    

输出

0
tensor([[ 0.2588, -0.0292,  1.0143],
    [ 0.1215, -0.0259, -1.1979],
    [ 0.2648,  1.7875,  0.3942],
    [-0.7355, -0.9454, -0.1084],
    [-0.1744,  0.1619,  0.5177]])
tensor([ 1.2439, -1.1023,  2.4465, -1.7893,  0.5051])
1
tensor([[ 0.6797, -0.3623, -0.2554],
    [-1.0481, -0.7783,  1.8088],
    [ 0.6535,  0.5184, -0.0382],
    [ 2.3790,  1.8096,  0.1110],
    [-0.3820, -1.5508,  0.3057]])
tensor([ 0.0619, -0.0176,  1.1337,  4.2997, -1.6271])
...

B:借助TensorDataset直接将数据包装成dataset类
另一种方法是直接使用 TensorDataset 来将数据包装成Dataset类,再使用dataloader。

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
 
src = torch.sin(torch.arange(1, 1000, 0.1))
trg = torch.cos(torch.arange(1, 1000, 0.1))
 
# 总共有9990个数据
data = TensorDataset(src, trg)
# 9990个数据被分成1998个batch,每batch有数据5个,所以data_loader的len为1998,可以从i_batch看出
data_loader = DataLoader(data, batch_size=5, shuffle=False)
for i_batch, batch_data in enumerate(data_loader):
    print(i_batch)  # 打印batch编号
    print(batch_data[0].size())  # 打印该batch里面src
    print(batch_data[1].size())  # 打印该batch里面trg

输出

0
torch.Size([5])
torch.Size([5])
1
torch.Size([5])
torch.Size([5])
2
torch.Size([5])
torch.Size([5])
...

类型2:iterable-style datasets

可迭代样式的数据集是IterableDataset的一个实例,该实例必须重写__iter__方法,该方法用于对数据集进行迭代。这种类型的数据集特别适合随机读取数据不太可能实现的情况,并且批处理大小batchsize取决于获取的数据。比如读取数据库,远程服务器或者实时日志等数据的时候,可使用该样式,一般时序数据不使用这种样式。

注意:
一般用data.Dataset处理同一个目录下的数据。如果数据在不同目录下,
因为不同的目录代表不同类别(这种情况比较普遍),使用data.Dataset来处理就很不方便。不过,使用PyTorch另一种可视化数据处理工具(即torchvision)就非常方便,不但可以自动获取标签,还提供很多数据预处理、数据增强等转换函数

Reference:

https://blog.csdn.net/zuiyishihefang/article/details/105985760
https://blog.csdn.net/weixin_42468475/article/details/108714940
https://blog.csdn.net/u011995719/article/details/85102770

  • 8
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI扩展坞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值