PyTorch学习笔记5——批训练

1、torch.utils.data.TensorDataset() 和torch.utils.data.DataLoader()
pytorch提供了一个数据读取的方法,其由两个类构成:torch.utils.data.Dataset和DataLoader,我们要自定义自己数据读取的方法,就需要继承torch.utils.data.Dataset,并将其封装到DataLoader中。

TensorDataset定义数据集用以包装数据和目标张量,便于传入DataLoader进行批训练,即将数据变成torch中DataLoader可使用的形式,继承了Dataset抽象类。

DataLoader是可迭代的数据装载器,可以迭代输出Dataset的内容,同时可以实现多进程、shuffle、不同采样策略,数据校对等等处理过程。

2、epoch、batch、iteration
在这里插入图片描述
3、enumerate()
enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

4、示例代码

import torch
import torch.utils.data as Data  # 进行批训练的模块
'''
P15:批训练
当数据非常大时。通过批训练将数据分小批次来训练,以提升神经网络的训练效率或速度
'''
BATCH_SIZE = 5  # 批大小

x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

torch_dataset = Data.TensorDataset(x,y)  # 定义包装数据和目标张量的数据库
# DataLoader为可迭代的数据装载器
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # 每个epoch是否打乱训练数据的顺序
    # num_workers=2,  # 是否多线程读取数据,win环境无需
)

for epoch in range(3):   # epoch:当一个完整的数据集通过了神经网络一次并且返回了一次,这个过程为一次epoch
    for step,(batch_x,batch_y) in enumerate(loader):  # 每一次整批训练时,都将数据拆分小批次训练
        # training

        print('Epoch: ',epoch,'| Step: ',step,'| batch x: ',batch_x.numpy(),'| batch y: ',batch_y.numpy())

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值