7 pytorch DataLoader, TensorDataset批数据训练方法

前言

本文主要介绍pytorch里面批数据的处理方法,以及这个算法的效果是什么样的。具体就是要弄明白这个批数据选取的算法是在干什么,不会涉及到网络的训练。

from torch.utils.data import DataLoader, TensorDataset

主要实现就是上面的数据集和数据载入两个类来实现该算法功能,这里只要求会调用接口就够了。

一、生成数据集

import torch
from torch.utils.data import DataLoader, TensorDataset
# 准备数据集与定义batch_size
batch_size = 8
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
print(x)
print(y)

输出:
在这里插入图片描述

二、将训练数据进行batch处理

# 将训练数据放入torch的数据集
train_dataset = TensorDataset(x, y)
# 载入batch批次选取数据规则
train_loader = DataLoader(train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True,   # True表示每一个epoch都打乱抽取
                          num_workers=2   # 定义工作线程个数
                          )

三、epoch训练

# 训练模型
epochs = 3
for epoch in range(epochs):
    # 每一个epoch表示将整个数据集所有数据都训练一遍
    for step,(batch_x, batch_y) in enumerate(train_loader):
        # training......
        # 这里用enumerate是为了让你更加情况观察,batch的逻辑是怎么样的
        # 实际中只要  for batch_x,batch_y in train_loader就可以了
        print('Epoch:',epoch,'| Step:',step,'| batch x:',batch_x.data.numpy(),'| batch y:',batch_y.data.numpy())
# 测试模型(略)

输出:
在这里插入图片描述
【注】:可以看到每一个epoch将所有样本点都涉及到了一次,并且还是打乱顺序了的。
下面看看将shuffle=False不打乱顺序会发生什么:
在这里插入图片描述
【注】:可以看到每一个epoch,都是相同的结果,可想而知这样训练效果肯定没有打乱的好。
注意到,上半batch=5,恰好将样本总数10均分为2分,那么要是不能均分会发生什么,下面将batch=8,看看会发生什么。
在这里插入图片描述
可以看到直接将不够的组就直接剩下的了。

总结

后面我们会经常用到这种batch和epoch的训练方法。

  • 7
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值