4-Datasets和DataLoader

75 篇文章 2 订阅

1.说明

pytorch中用dataset来对单个样本进行打包features和labels,得到datasets=(features,labels);用DataLoader来包装Datasets,使得可以每一个批次batchsize打包起来,得到一个批量大小的datasets
在这里插入图片描述

  • torch.utils.data.Dataset:打包(features_i,labels_i)得到dataset
  • torch.utils.data.DataLoader:打包dataset_i 得到 DataLoader

2. 从pytorch中得到datasets

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

# 训练的datasets
training_data = datasets.FashionMNIST(
	# root:表示地址
    root="data",
    # True:表示是否是训练集
    train=True,
    download=True,
    # transform:表示的是将图片转换张量,这里可以对图片进行相关预处理
    transform=ToTensor()
)

# 测试的datasets
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

3. 自定义datasets

自定义dataset需要满足4个条件

  • 继承自官方的datasets类
class CustomImageDataset(Dataset):
  • 覆写初始化函数 __init__
    当实例化Dataset对象时,__init__函数运行一次。我们初始化包含图像、注释文件和两个转换的目录
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

(1)annotations_file:表示的是labels的csv格式文件名
(2)img_dir:表示的是features图片目录
(3)transform:表示的是对图片features进行预处理
(4)target_transform:表示的是对标签labels进行预处理

  • 覆写长度__len__函数
    __len__函数返回数据集中的样本数量
def __len__(self):
    return len(self.img_labels)
  • 覆写 __getitem__函数
    作用:根据给定的index序号从datasets中返回特征features和标签
def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
  • img_path:图片的位置
  • label:根据idx得到对应的标签label
  • self.transform:将对应的图片进行转换
  • self.target_transform:将标签转换预处理

4. 定义DataLoader

我们定义好一个样本的datasets,我们需要将多个datasets转换成一个批量大小的DataLoader;我们经常需要将一个批次里面的datasets进行打乱处理再训练,所以我们需要定义DataLoader;

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
  • training_data: 训练集的datasets
  • test_data:测试集的datasets
  • batch_size:批量大小,需要多少个datasets
  • shuffle:一个批量大小中的datasets是否需要打乱,为了提高模型的鲁棒性

4.1 DataLoader

在文件dataloader.py中有定义

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
  • dataset:传入的datasets
  • batch_size: 定义批量大小,默认为1
  • shuffle:是否打乱一个批量大小里面的datasets
  • sampler: 采样
  • batch_sampler:批量采样
  • num_workers:为数据使用多少子流程;可以给程序启动多进程处理
  • collate_fn: 用于如何取样本,可以自己定义如何对样本的取出处理
  • pin_memory:将数据固定到GPU上
  • drop_last:如果为真,删除最后一个不完整的batch
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
datasets 和 dataloaderPyTorch 中用于处理和加载数据的两个重要模块。 在 PyTorch 中,datasets 用于存储和处理数据集,例如图像、文本等。PyTorch 提供了许多内置的 datasets,如 torchvision 中的 ImageFolder 和 MNIST,也可以自定义 datasets。 下面是使用 torchvision.datasets.ImageFolder 加载图像数据集的示例代码: ```python import torchvision.datasets as datasets # 定义数据集路径 data_dir = 'path/to/dataset' # 创建 ImageFolder 数据集 dataset = datasets.ImageFolder(data_dir) # 获取数据集的长度 dataset_size = len(dataset) # 获取类别标签 class_labels = dataset.classes # 可以通过索引访问数据集中的样本 sample, label = dataset[0] # 可以通过迭代器遍历整个数据集 for sample, label in dataset: # 在这里对样本进行处理/转换 pass ``` 接下来,我们可以使用 dataloader 对数据集进行批量加载和预处理。dataloader 可以方便地将数据集划分为小批量样本,进行数据增强或标准化等操作。 下面是使用 torch.utils.data.DataLoader 对数据集进行批量加载的示例代码: ```python import torch.utils.data as data # 定义批量大小和多线程加载数据的工作进程数 batch_size = 32 num_workers = 4 # 创建 dataloader dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) # 可以通过迭代器遍历整个数据集的小批量样本 for batch_samples, batch_labels in dataloader: # 在这里对小批量样本进行处理/转换 pass ``` 在上面的示例中,我们创建了一个 dataloader,并指定了批量大小和加载数据的工作进程数。`shuffle=True` 表示每个 epoch 都会对数据进行随机打乱,以增加数据的多样性。 通过使用 datasets 和 dataloader,我们可以方便地加载和处理各种类型的数据集,并应用各种预处理操作。这些模块的使用可以大大简化数据加载和处理的过程,提高代码的可读性和可维护性。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值