1.大致流程
pytorch中加载数据的顺序是:
①创建一个dataset对象
②创建一个dataloader对象
③循环dataloader对象,将data,label拿到模型中去训练
2.Dataset
你需要自己定义一个class,里面至少包含3个函数:
(特别要注意的是输入进函数的数据一定得是可迭代的。如果是自定的数据集的话可以在定义类中用def__len__、def__getitem__定义。)
①__init__:传入数据,或者像下面一样直接在函数里加载数据
②__len__:返回这个数据集一共有多少个item
③__getitem__:返回一条训练数据,并将其转换成tensor
3.Dataloader
参数:
dataset:传入的数据
shuffle = True:是否打乱数据
collate_fn:使用这个参数可以自己操作每个batch的数据
(collate_fn暂时用不到,可以参考Pytorch中DataLoader的使用_kahuifu的博客-CSDN博客_dataloader)
4.按照batch取数据和标签
5.代码
import torch
from torch.utils.data import DataLoader,Dataset
import numpy as np
class Mydata(Dataset):
def __init__(self, train_x, train_label):
self.train_x = train_x
self.train_label = train_label
def __getitem__(self, item):
assert item<len(self.train_x)
return self.train_x[item],self.train_label[item]
def __len__(self):
return len(self.train_x)
train_x = np.zeros((4,3))
train_label = np.arange(4).reshape((-1,1))
# print(train_label)
dataset = Mydata(train_x,train_label)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
for i,data in enumerate(dataloader):
print(i,data[:-1])
print(data[-1])
执行结果
(注:如果在定义Dataset类的时候,在方法__getitem__中不加入return self.train_label[item]这一条命令的话,最终在按照batch取数据的时候,不会取label,只会取训练数据)
0 [tensor([[0., 0., 0.],
[0., 0., 0.]], dtype=torch.float64)]
tensor([[1],
[0]], dtype=torch.int32)
1 [tensor([[0., 0., 0.],
[0., 0., 0.]], dtype=torch.float64)]
tensor([[3],
[2]], dtype=torch.int32)
Process finished with exit code 0