PyTorch DataLoader

在使用PyTorch时,经常需要自己写DataLoader,写DataLoader前需要先写Dataset,通常情况可能需要给定一个file_path动态加载,这里先定义一个X和Y。

Dataset类需要实现__getitem__和__len__两个函数


from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import torch.utils.data as Data
import torch


class MyDataSet(Dataset):
    def __init__(self, file_path):
        super().__init__()
        self.X = torch.randn(20)
        self.Y = torch.linspace(1, 10, 20)
        self._len = len(self.X)
        pass
    def __getitem__(self, idx):
        return (self.X[idx], self.Y[idx])
    
    def __len__(self):
        return self._len

我们测试下刚刚的Dataset并建立DataLoader,DataLoader的好处是可以一次生成batch个数据,方便训练和计算。

data_set = MyDataSet('')

data_loder = DataLoader(dataset=data_set, batch_size=6, shuffle=True)

for batch in data_loader:
    print(batch)
会得到类似这样的输出:

----------------------
[tensor([6., 4., 3.]), tensor([5., 7., 8.])]
[tensor([ 2.,  5., 10.]), tensor([9., 6., 1.])]
[tensor([7., 1., 9.]), tensor([ 4., 10.,  2.])]
[tensor([8.]), tensor([3.])]

-------------------------------------

[tensor([7., 3., 9.]), tensor([4., 8., 2.])]
[tensor([4., 6., 1.]), tensor([ 7.,  5., 10.])]
[tensor([ 5.,  8., 10.]), tensor([6., 3., 1.])]
[tensor([2.]), tensor([9.])]

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值