pytorch中Dataloader的使用

1.大致流程

pytorch中加载数据的顺序是:
①创建一个dataset对象
②创建一个dataloader对象
③循环dataloader对象,将data,label拿到模型中去训练

2.Dataset

你需要自己定义一个class,里面至少包含3个函数:

(特别要注意的是输入进函数的数据一定得是可迭代的。如果是自定的数据集的话可以在定义类中用def__len__、def__getitem__定义。)
①__init__:传入数据,或者像下面一样直接在函数里加载数据
②__len__:返回这个数据集一共有多少个item
③__getitem__:返回一条训练数据,并将其转换成tensor

3.Dataloader

参数:
dataset:传入的数据
shuffle = True:是否打乱数据
collate_fn:使用这个参数可以自己操作每个batch的数据

(collate_fn暂时用不到,可以参考Pytorch中DataLoader的使用_kahuifu的博客-CSDN博客_dataloader

4.按照batch取数据和标签

5.代码

import torch
from torch.utils.data import DataLoader,Dataset
import numpy as np

class Mydata(Dataset):
    def __init__(self, train_x, train_label):
        self.train_x = train_x
        self.train_label = train_label

    def __getitem__(self, item):
        assert item<len(self.train_x)
        return self.train_x[item],self.train_label[item]

    def __len__(self):
        return len(self.train_x)

train_x = np.zeros((4,3))
train_label = np.arange(4).reshape((-1,1))
# print(train_label)
dataset = Mydata(train_x,train_label)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

for i,data in enumerate(dataloader):
    print(i,data[:-1])
    print(data[-1])

执行结果

(注:如果在定义Dataset类的时候,在方法__getitem__中不加入return self.train_label[item]这一条命令的话,最终在按照batch取数据的时候,不会取label,只会取训练数据)

0 [tensor([[0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)]
tensor([[1],
        [0]], dtype=torch.int32)
1 [tensor([[0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)]
tensor([[3],
        [2]], dtype=torch.int32)

Process finished with exit code 0
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值