Torch_3_Dataset与Dataloader

本文介绍了PyTorch中用于训练模型的数据加载方式,包括自定义`Dataset`子类来处理数据集,以及使用`DataLoader`进行批量数据加载。示例展示了如何从numpy文件读取数据,并在训练过程中通过迭代`DataLoader`获取批次数据。这种方式统一了数据输入,提高了代码可读性和易用性。
摘要由CSDN通过智能技术生成

torch中的data迭代方法

介绍


  • 看代码的过程中不难发现,不同作者模型训练时的数据输入方法差别非常大。
  • torch提供了统一的接口,通过迭代器实现数据和标签的读取,使用方便也利于阅读。


实现方法


  • 导入

    from torch.utils.data import Dataset, DataLoader
    
  • Dataset

    • torch内置抽象类,无法实例化,通过继承并重写魔术方法实现
    class MyDataset(Dataset):
        def __init__(self, filepath):
            xy = np.load(filepath)
            self.len = xy.shape[0]
            self.x_data = torch.from_numpy(xy[:, :-1])
            self.y_data = torch.from_numpy(xy[:, [-1]])
    
        def __getitem__(self, item):
            return self.x_data[item], self.y_data[item]
    
        def __len__(self):
            return self.len
    
    dataset = MyDataset('MyData.npy')
    
    • 示例中,以读取numpy文件为例,通过重写__getitem____len__方法,实现数据的随机读取

  • Dataloader

    • 调用dataset 实例,通过设定的参数可生成DataLoader
    train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
    

  • 训练中调用数据

    for i, data in enumerate(train_loader, 0):  #
    		x, y = data
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值