Pytorch的数据加载

PyTorch将数据集的处理过程标准化,提供了Dataset基本的数据 类,并在torchvision中提供了众多数据变换函数,数据加载的具体过程 主要分为3步:
1.继承Dataset类
对于数据集的处理,PyTorch提供了torch.utils.data.Dataset这个抽象 类,在使用时只需要继承该类,并重写__len__()和__getitem()__函数, 即可以方便地进行数据集的迭代。

from torch.utils.data import Dataset
class my_data(Dataset):
     def _init_(self,image_path,annotation_path,transform-None):
     #初始化,读取数据集
 def _len_(self):
      #获取数据集的总大小
  def _getitem_(self,id):
      #对于指定的id,读取数据并返回

对上述初始化的·实际使用:

dataset = my_data("your image path", "your annotation path")
 # 实例化该类 for data in dataset:
     print(data)

2.数据变换与增强:torchvision.transforms
第一步虽然将数据集加载到了实例中,但在实际应用时,数据集中 的图片有可能存在大小不一的情况,并且原始图片像素RGB值较大 (0~255),这些都不利于神经网络的训练收敛,因此还需要进行一些 图像变换工作。PyTorch为此提供了torchvision.transforms工具包,可以 方便地进行图像缩放、裁剪、随机翻转、填充及张量的归一化等操作, 操作对象是PIL的Image或者Tensor。
如果需要进行多个变换功能,可以利用transforms.Compose将多个 变换整合起来,并且在实际使用时,通常会将变换操作集成到Dataset的 继承类中。具体示例如下:

from torchvision import transforms
#将transform集成到dataset类中,使用compose将多个变换整合到一起
dataset = my_data("your image path","your annotation path",transforms=transforms.Compose([
transforms.Resize(256) #将图像最短边缩小至256,宽高比例不变
#以0.5的概率随即翻转指定的PIL图像
transforms.RandomHorizaontalFlip()
#将PIL图像转为Tensor,元素区间从[0,255]归一化到[0,1]
transforms.ToTensor()
#进行mean与std为0.5的标准化
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
]))

3.继承dataloader
经过前两步已经可以获取每一个变换后的样本,但是仍然无法进行 批量处理、随机选取等操作,因此还需要torch.utils.data.Dataloader类进 一步进行封装,使用方法如下例所示,该类需要4个参数,第1个参数是 之前继承了Dataset的实例,第2个参数是批量batch的大小,第3个参数是 是否打乱数据参数,第4个参数是使用几个线程来加载数据。

from torch.utils.data import Dataloader
# 使用Dataloader进一步封装Dataset 
dataloader = Dataloader(dataset, batch_size=4,                                 shuffle=True, num_workers=4)

dataloader是一个可迭代对象,对该实例进行迭代即可用于训练过程。

data_iter = iter(dataloader)
for step in range(iters_per_epoch):
 data = next(data_iter)        # 将data用于训练网络即可
PyTorch是一个基于Python的科学计算包,其主要功能是进行张量计算和深度学习模型构建。在深度学习中,数据加载是一个重要的环节,PyTorch提供了一些工具和函数来简化数据加载的过程。 PyTorch数据加载主要涉及到两个类:`torch.utils.data.Dataset`和`torch.utils.data.DataLoader`。其中,`Dataset`类用于表示数据集,而`DataLoader`类则用于对数据集进行加载和处理。 使用PyTorch进行数据加载的基本步骤如下: 1. 定义数据集:需要继承`torch.utils.data.Dataset`类,并实现`__len__`和`__getitem__`方法。其中,`__len__`方法返回数据集的大小,`__getitem__`方法用于获取指定索引的数据。 2. 创建数据集实例:将定义好的数据集实例化,并传入相应的参数(如文件路径等)。 3. 创建数据加载器:使用`torch.utils.data.DataLoader`类创建数据加载器,可以指定批次大小、是否打乱数据、多进程等参数。 4. 迭代数据:使用for循环迭代数据加载器,每次迭代返回一个批次的数据。 下面是一个简单的示例代码,用于加载MNIST数据集: ```python import torch from torch.utils.data import Dataset, DataLoader from torchvision import datasets, transforms # 定义自己的数据集类 class MyDataset(Dataset): def __init__(self, path): self.data = torch.load(path) self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) def __len__(self): return len(self.data) def __getitem__(self, index): x, y = self.data[index] x = self.transform(x) return x, y # 创建数据集实例 train_dataset = MyDataset('mnist/train.pt') test_dataset = MyDataset('mnist/test.pt') # 创建数据加载器 train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True) # 迭代数据 for batch_idx, (data, target) in enumerate(train_loader): # 对批次数据进行训练或测试 ... ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值