DataLoader学习笔记

原文连接:PyTorch DataLoader: A Complete Guide • datagy

1、理解dataloader类

# Understanding the PyTorch DataLoader Class

from torch.utils.data import DataLoader

DataLoader(
    dataset, 
    batch_size=1, 
    shuffle=False, 
    sampler=None, 
    batch_sampler=None, 
    num_workers=0, 
    collate_fn=None, 
    pin_memory=False, 
    drop_last=False, 
    timeout=0, 
    worker_init_fn=None, 
    multiprocessing_context=None, 
    generator=None, 
    *, 
    prefetch_factor=2, 
    persistent_workers=False
)

2、创建和使用PyTorch dataloader类

# Loading the MNIST Dataset Using PyTorch
# Importing Libraries
from torchvision.datasets import MNIST

# Downloading and Saving MNIST 
data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())

# Accessing a Dataset Item
print(data_train[0])

# Returns:
# (tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# ...
#           0.0000, 0.0000, 0.0000, 0.0000]]]), 5)

可视化一个例子:

# Visualizing a Sample
import matplotlib.pyplot as plt
plt.imshow(data_train.data[0])
plt.show()

输出样例如下:

加载数据集,然后创建自己的dataloader

# Creating a Training DataLoader Object
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Downloading and Saving MNIST 
data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())

# Creating Data Loader
data_loader = DataLoader(data_train, batch_size=20, shuffle=True)

print(data_loader)

# Returns:
# <torch.utils.data.dataloader.DataLoader object at 0x7fc3c021b6d0>

3、在PyTorch dataloader上迭代

# Loading the First Batch and Printing Information
for idx, batch in enumerate(data_loader):
    print('Batch index: ', idx)
    print('Batch size: ', batch[0].size())
    print('Batch label: ', batch[1])
    break

# Returns:
# Batch index:  0
# Batch size:  torch.Size([20, 1, 28, 28])
# Batch label:  tensor([3, 3, 7, 7, 2, 4, 7, 2, 1, 8, 3, 3, 9, 3, 2, 3, 5, 0, 6, 8])

4、在PyTorch dataloader上加载数据和目标

# Accessing Data and Targets in a PyTorch DataLoader
for idx, (data, target) in enumerate(data_loader):
    print(data[0])
    print(target[0])
    break

# Returns:
# tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# ...
#           0.0000, 0.0000, 0.0000, 0.0000]]])
# tensor(1)

5、用PyTorch dataloader加载数据到GPU(CUDA)

# Loading Data to a GPU with a PyTorch DataLoader Object
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import torch

data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())
data_loader = DataLoader(data_train, batch_size=20, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for idx, (data, target) in enumerate(data_loader):
    data = data.to(device)
    target = target.to(device)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
DataLoader()是PyTorch中用于加载数据集的函数。在使用DataLoader()函数时,首先需要创建一个数据集对象,如train_data = trainset(),然后将该数据集对象作为参数传入DataLoader()中,同时可以设置一些参数,如batch_size(每个批次的样本数量)和shuffle(是否打乱数据集)。最后,通过训练循环可以迭代地从DataLoader对象中获取每个批次的数据样本。 根据引用中的博文内容,DataLoader的主要使用包括以下几个方面: 1. DataLoader的基础使用:创建数据集对象和DataLoader对象,并设置一些参数。 2. 数据集和DataLoader的区别:数据集对象是存储和处理数据的对象,而DataLoader对象是用于加载数据集的对象。 3. 数据集的处理:可以根据实际需求自定义数据集对象,如MyDataset(),并将该数据集对象传入DataLoader中进行加载。 4. 在训练循环中使用DataLoader:通过迭代DataLoader对象即可逐批次获取数据样本进行训练。 综上所述,DataLoader()函数是PyTorch中用于加载数据集的函数,可以方便地进行数据的批量加载和处理。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [AGPCNet——dataloader()函数](https://blog.csdn.net/python_Ezreal/article/details/125094746)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [PyTorch学习笔记(4)--DataLoader的使用](https://blog.csdn.net/weixin_43981621/article/details/119685671)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值