pytorch设置batch

使用pytorch进行数据网络训练时,数据集可能有上万条数据,训练的话比较浪费时间,设置batch,一次训练一个batch_size的大小,既节省时间,又可以快速收敛。
使用前需要导入包:

from torch.utils.data import Dataset, DataLoader, TensorDataset

设置batch,需要将训练数据的输入属性和标签放入DataLoader中,见下:


def addbatch(data_train,data_test,batchsize):
    """
    设置batch
    :param data_train: 输入
    :param data_test: 标签
    :param batchsize: 一个batch大小
    :return: 设置好batch的数据集
    """
    data = TensorDataset(data_train,data_test)
    data_loader = DataLoader(data, batch_size=batchsize, shuffle=False)#shuffle是是否打乱数据集,可自行设置

    return data_loader

使用时调用即可:

#设置batch
    traindata=addbatch(traininput,trainlabel,1000)#1000为一个batch_size大小为1000,训练集为10000时一个epoch会训练10次。

进行神经网络训练用下面方法:

    for epoch in range(EPOCH):
        for step, data in enumerate(traindata):
            inputs, labels = data
            # 前向传播
            out = net(inputs)
            # 计算损失函数
            loss = loss_func(out, labels)
            # 清空上一轮的梯度
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

<编程路上>

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值