pytorch笔记:Dataloader

Dataloader是 torch 给我们用来包装数据的工具。所以我们要将自己的 (ndarray或其他) 数据形式装换成 Tensor, 然后再放进Dataloader这个包装器中。 使用Dataloader有什么好处呢? 就是它可以帮我们有效地迭代数据。

1 准备部分

1.1 导入库

import torch
import torch.utils.data as Data

1.2 数据集部分

x=torch.linspace(1,10,10)
y=torch.linspace(-10,-1,10)
print(x,'\n',y)
'''
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]) 
tensor([-10.,  -9.,  -8.,  -7.,  -6.,  -5.,  -4.,  -3.,  -2.,  -1.])
'''
BATCH_SIZE=5

2 方法1 TensorDataset+DataLoader

torch_dataset=Data.TensorDataset(x,y)
#先转化成pytorch可以识别的Dataset格式

loader=Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True)
#把dataset导入dataloader,并设置batch_size和shuffle

for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):
        print('epoch: ',epoch)
        print('step: ',step,'\n x: ',batch_x,'\n y: ',batch_y)
        print('*'*30)

# 注:也可以直接:
#for batch_x,batch_y in loader:
'''
epoch:  0
step:  0 
 x:  tensor([9., 2., 6., 5., 3.]) 
 y:  tensor([-2., -9., -5., -6., -8.])
******************************
epoch:  0
step:  1 
 x:  tensor([10.,  1.,  7.,  4.,  8.]) 
 y:  tensor([ -1., -10.,  -4.,  -7.,  -3.])
******************************
epoch:  1
step:  0 
 x:  tensor([ 3.,  5.,  2., 10.,  4.]) 
 y:  tensor([-8., -6., -9., -1., -7.])
******************************
epoch:  1
step:  1 
 x:  tensor([7., 6., 8., 1., 9.]) 
 y:  tensor([ -4.,  -5.,  -3., -10.,  -2.])
******************************
epoch:  2
step:  0 
 x:  tensor([10.,  3.,  1.,  8.,  9.]) 
 y:  tensor([ -1.,  -8., -10.,  -3.,  -2.])
******************************
epoch:  2
step:  1 
 x:  tensor([5., 7., 2., 6., 4.]) 
 y:  tensor([-6., -4., -9., -5., -7.])
******************************
'''

3 方法2 自定义类+DataLoader

class MyDataSet(Data.Dataset):
    def __init__(self,x,y):
        super(MyDataSet,self).__init__()
        self.x=x
        self.y=y
    
    def __len__(self):
        return self.x.shape[0]
    #有几组数据
    
    def __getitem__(self,idx):
        return(self.x[idx],self.y[idx])
    #根据索引找到数据

loader2=Data.DataLoader(
    MyDataSet(x,y),
    batch_size=BATCH_SIZE,
    shuffle=True)

for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader2):
        print('epoch: ',epoch)
        print('step: ',step,'\n x: ',batch_x,'\n y: ',batch_y)
        print('*'*30)


'''
epoch:  0
step:  0 
 x:  tensor([9., 7., 2., 6., 3.]) 
 y:  tensor([-2., -4., -9., -5., -8.])
******************************
<class 'torch.Tensor'>
epoch:  0
step:  1 
 x:  tensor([ 5.,  4.,  8., 10.,  1.]) 
 y:  tensor([ -6.,  -7.,  -3.,  -1., -10.])
******************************
<class 'torch.Tensor'>
epoch:  1
step:  0 
 x:  tensor([ 6.,  3.,  5., 10.,  9.]) 
 y:  tensor([-5., -8., -6., -1., -2.])
******************************
<class 'torch.Tensor'>
epoch:  1
step:  1 
 x:  tensor([4., 7., 1., 8., 2.]) 
 y:  tensor([ -7.,  -4., -10.,  -3.,  -9.])
******************************
<class 'torch.Tensor'>
epoch:  2
step:  0 
 x:  tensor([ 6.,  1., 10.,  3.,  4.]) 
 y:  tensor([ -5., -10.,  -1.,  -8.,  -7.])
******************************
<class 'torch.Tensor'>
epoch:  2
step:  1 
 x:  tensor([2., 9., 7., 5., 8.]) 
 y:  tensor([-9., -2., -4., -6., -3.])
******************************
'''

4 collate_fn

  • collate_fn 是一个参数,通常在 PyTorch 中的 DataLoader 类中使用,它允许用户指定一个函数来决定如何将多个数据样本合并成一个批次
  • 主要作用——自定义数据批次的创建
    • 但如果数据样本的大小或形式不一,比如列表长度不一或字典结构不同,则需要自定义方法来正确合并这些数据。
    • 默认情况下,DataLoader 简单地将这些样本堆叠在一起,这适用于多数情况,特别是当所有数据样本形状相同时。
    • 在加载数据时,DataLoader 需要将多个数据样本(通常来自 Dataset 对象的 __getitem__ 方法)组合成数据批次
from torch.utils.data import DataLoader

def my_collate_fn(batch):
    # `batch` 是一个列表,其中包含了从 `Dataset.__getitem__` 返回的数据样本
    # 在这里实现将这些样本合并为一个批次的逻辑
    ...
    return batch

dataloader = DataLoader(my_dataset, batch_size=4, collate_fn=my_collate_fn)

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值