利用torch.utils.DataLoader进行批训练

先贴代码:

import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)  # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)  # this is y data (torch tensor)

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

if __name__ == "__main__":
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x, '| batch y: ', batch_y)

运行结果为:

Epoch:  0 | Step:  0 | batch x:  tensor([10.,  7.,  2.,  4.,  9.]) | batch y:  tensor([1., 4., 9., 7., 2.])
Epoch:  0 | Step:  1 | batch x:  tensor([6., 1., 8., 3., 5.]) | batch y:  tensor([ 5., 10.,  3.,  8.,  6.])
Epoch:  1 | Step:  0 | batch x:  tensor([2., 1., 7., 8., 6.]) | batch y:  tensor([ 9., 10.,  4.,  3.,  5.])
Epoch:  1 | Step:  1 | batch x:  tensor([ 3., 10.,  9.,  5.,  4.]) | batch y:  tensor([8., 1., 2., 6., 7.])
Epoch:  2 | Step:  0 | batch x:  tensor([ 5.,  1., 10.,  6.,  8.]) | batch y:  tensor([ 6., 10.,  1.,  5.,  3.])
Epoch:  2 | Step:  1 | batch x:  tensor([3., 9., 4., 7., 2.]) | batch y:  tensor([8., 2., 7., 4., 9.])

先说此段代码发现的一个小细节,就是在DataLoader中设置了num_workers,此代码的目的是使用cpu中的多线程机制来实现代码快速的快速运行,但是在运行过程中,只有放在if __name__ = "__main__"(主函数中)才不会出现报错。
PyTorch 数据加载实用程序torch.utils.data中的TensorDataset类的作用是使用数据集来包装张量,分别将feature和label传入即可。
torch.utils.data的核心是Dataloader类,DataLoader为数据加载器,将数据集和样本集组合在一起。DataLoader支持单进程(线程)或者多进程(线程)来加载数据集并且可以实现批训练来提高训练效率。上面的代码中包含了,dataset = torch_dataset使用我们打包好的数据集,batch_size=5使用了按批传入的功能,shuffle在传入的时候打乱了数据集,num_workers=2使用了两个线程,提高运行速度和系统效率。

具体内容可见pytorch文档

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值