pytorch:批训练(batch_training)

一、什么是批训练?

批训练: 意思是把整套训练数据分成数个批次进行训练,每个批次从数据中选取 n_num(总数据)/n_batch(批次) 个数据,直到把整套数据训练完成。
举个例子,有10个输入训练数据,每个批次训练3个数据,这训练完成需要4个批次,第一次从[1,2,3,4,5,6,7,8,9,0]随机选取3个数据,batch_1=[2,5,9],第二次再从剩余的7个数据中选取,得到batch_2=[1,4,8],第三次又从剩余的4个数据中选取3个,得到batch_3=[4,6,0],数据中只剩余1个数据,那么batch_4=[7]。直到把整套数据都作为数据训练完成为止。

二、批训练有什么好处?

1、内存利用率提高了,大矩阵乘法的并行化效率提高。
2、跑完一次 epoch(全数据集)所需的迭代次数减少,对于相同数据量的处理速度进一步加快。
3、在一定范围内,一般来说 Batch_Size 越大,其确定的下降方向越准,引起训练震荡越小。

三、程序:

import torch
import torch.utils.data as data

# 创建数据
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

# 先转换成torch能识别的dataset
torch_dataset = data.TensorDataset(x, y)

# 把dataset放入DataLoader
loader = data.DataLoader(
    dataset=torch_dataset,
    batch_size=4,             # 每批提取的数量
    shuffle=True,             # 要不要打乱数据(打乱比较好)
    num_workers=2             # 多少线程来读取数据
)

if __name__ == '__main__':
    for epoch in range(3):    # 对整套数据训练3次
        for step, (batch_x, batch_y) in enumerate(loader):  # 每一步loader释放一小批数据用来学习
            # 训练过程

            # 打印数据
            print("epoch:", epoch, "step:", step, 'batch_x:', batch_x.numpy(), 'batch_y:', batch_y.numpy())

4、结果:

在这里插入图片描述

©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页