Dataloader是 torch 给我们用来包装数据的工具。所以我们要将自己的 (ndarray或其他) 数据形式装换成 Tensor, 然后再放进Dataloader这个包装器中。 使用Dataloader有什么好处呢? 就是它可以帮我们有效地迭代数据。
1 准备部分
1.1 导入库
import torch
import torch.utils.data as Data
1.2 数据集部分
x=torch.linspace(1,10,10)
y=torch.linspace(-10,-1,10)
print(x,'\n',y)
'''
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
tensor([-10., -9., -8., -7., -6., -5., -4., -3., -2., -1.])
'''
BATCH_SIZE=5
2 方法1 TensorDataset+DataLoader
torch_dataset=Data.TensorDataset(x,y)
#先转化成pytorch可以识别的Dataset格式
loader=Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True)
#把dataset导入dataloader,并设置batch_size和shuffle
for epoch in range(3):
for step,(batch_x,batch_y) in enumerate(loader):
print('epoch: ',epoch)
print('step: ',step,'\n x: ',batch_x,'\n y: ',batch_y)
print('*'*30)
# 注:也可以直接:
#for batch_x,batch_y in loader:
'''
epoch: 0
step: 0
x: tensor([9., 2., 6., 5., 3.])
y: tensor([-2., -9., -5., -6., -8.])
******************************
epoch: 0
step: 1
x: tensor([10., 1., 7., 4., 8.])
y: tensor([ -1., -10., -4., -7., -3.])
******************************
epoch: 1
step: 0
x: tensor([ 3., 5., 2., 10., 4.])
y: tensor([-8., -6., -9., -1., -7.])
******************************
epoch: 1
step: 1
x: tensor([7., 6., 8., 1., 9.])
y: tensor([ -4., -5., -3., -10., -2.])
******************************
epoch: 2
step: 0
x: tensor([10., 3., 1., 8., 9.])
y: tensor([ -1., -8., -10., -3., -2.])
******************************
epoch: 2
step: 1
x: tensor([5., 7., 2., 6., 4.])
y: tensor([-6., -4., -9., -5., -7.])
******************************
'''
3 方法2 自定义类+DataLoader
class MyDataSet(Data.Dataset):
def __init__(self,x,y):
super(MyDataSet,self).__init__()
self.x=x
self.y=y
def __len__(self):
return self.x.shape[0]
#有几组数据
def __getitem__(self,idx):
return(self.x[idx],self.y[idx])
#根据索引找到数据
loader2=Data.DataLoader(
MyDataSet(x,y),
batch_size=BATCH_SIZE,
shuffle=True)
for epoch in range(3):
for step,(batch_x,batch_y) in enumerate(loader2):
print('epoch: ',epoch)
print('step: ',step,'\n x: ',batch_x,'\n y: ',batch_y)
print('*'*30)
'''
epoch: 0
step: 0
x: tensor([9., 7., 2., 6., 3.])
y: tensor([-2., -4., -9., -5., -8.])
******************************
<class 'torch.Tensor'>
epoch: 0
step: 1
x: tensor([ 5., 4., 8., 10., 1.])
y: tensor([ -6., -7., -3., -1., -10.])
******************************
<class 'torch.Tensor'>
epoch: 1
step: 0
x: tensor([ 6., 3., 5., 10., 9.])
y: tensor([-5., -8., -6., -1., -2.])
******************************
<class 'torch.Tensor'>
epoch: 1
step: 1
x: tensor([4., 7., 1., 8., 2.])
y: tensor([ -7., -4., -10., -3., -9.])
******************************
<class 'torch.Tensor'>
epoch: 2
step: 0
x: tensor([ 6., 1., 10., 3., 4.])
y: tensor([ -5., -10., -1., -8., -7.])
******************************
<class 'torch.Tensor'>
epoch: 2
step: 1
x: tensor([2., 9., 7., 5., 8.])
y: tensor([-9., -2., -4., -6., -3.])
******************************
'''
4 collate_fn
collate_fn
是一个参数,通常在 PyTorch 中的DataLoader
类中使用,它允许用户指定一个函数来决定如何将多个数据样本合并成一个批次- 主要作用——自定义数据批次的创建:
- 但如果数据样本的大小或形式不一,比如列表长度不一或字典结构不同,则需要自定义方法来正确合并这些数据。
- 默认情况下,
DataLoader
简单地将这些样本堆叠在一起,这适用于多数情况,特别是当所有数据样本形状相同时。 - 在加载数据时,
DataLoader
需要将多个数据样本(通常来自Dataset
对象的__getitem__
方法)组合成数据批次
from torch.utils.data import DataLoader
def my_collate_fn(batch):
# `batch` 是一个列表,其中包含了从 `Dataset.__getitem__` 返回的数据样本
# 在这里实现将这些样本合并为一个批次的逻辑
...
return batch
dataloader = DataLoader(my_dataset, batch_size=4, collate_fn=my_collate_fn)