使用pytorch进行数据网络训练时,数据集可能有上万条数据,训练的话比较浪费时间,设置batch,一次训练一个batch_size的大小,既节省时间,又可以快速收敛。
使用前需要导入包:
from torch.utils.data import Dataset, DataLoader, TensorDataset
设置batch,需要将训练数据的输入属性和标签放入DataLoader中,见下:
def addbatch(data_train,data_test,batchsize):
"""
设置batch
:param data_train: 输入
:param data_test: 标签
:param batchsize: 一个batch大小
:return: 设置好batch的数据集
"""
data = TensorDataset(data_train,data_test)
data_loader = DataLoader(data, batch_size=batchsize, shuffle=False)#shuffle是是否打乱数据集,可自行设置
return data_loader
使用时调用即可:
#设置batch
traindata=addbatch(traininput,trainlabel,1000)#1000为一个batch_size大小为1000,训练集为10000时一个epoch会训练10次。
进行神经网络训练用下面方法:
for epoch in range(EPOCH):
for step, data in enumerate(traindata):
inputs, labels = data
# 前向传播
out = net(inputs)
# 计算损失函数
loss = loss_func(out, labels)
# 清空上一轮的梯度
optimizer.zero_grad()
# 反向传播
loss.backward()
# 参数更新
optimizer.step()
enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。