简单回归问题
以简单回归问题为例,实现神经网络的小批量训练、网络参数保存以及参数提取。
简单回归问题的神经网络实现可见:【简单回归问题的神经网络实现-pytorch】
dataLoader定义
dataLoader是torch提供用于封装数据的工具,可以有效实现网络训练过程中的批量训练问题。
#生成DataLoader数据结构
def dataLoader(x,y):
#将torch转换为Dataset
torch_dataset = Data.TensorDataset(x, y)
#将dataset放入DataLoader
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=20, #最小训练批量
shuffle=True, #是否对数据进行随机打乱
num_workers=2, #多线程来读数据
)
return loader
使用方法如下:
#模拟数据
x,y=dataSet()
loader=dataLoader(x,y)
#迭代训练
for epoch in range(40):
lossAll=0
for step, (batch_x, batch_y) in enumerate(loader):
#预测
prediction=net(batch_x)
#计算误差
loss=loss_fun(prediction,batch_y)
lossAll+=loss.data.numpy()
#梯度降为0
optimizer.zero_grad()
#反向传递
loss.backward()
#优化梯度
optimizer.step()
#打印误差
print('Epoch: ', epoch, '| Step: ', step, '| loss: ',loss.data.numpy())
迭代过程如下:
参数保存
#保存网络
def saveNet(net,params=False):
if params:
#保存网络参数
torch.save(net.state_dict(),'net_params.pkl')
else:
#保存整个网络
torch.save(net,'net.pkl')
参数提取
#提取网络
def restoreNet(params=False):
if params:
#提取网络参数->注意需要新建一个相同类型的网络
net1=Net(1,[10,20],1)
net1.load_state_dict(torch.load('net_params.pkl'))
else:
#保存整个网络
net1 = torch.load('net.pkl')
====================================
今天到此为止,后续记录其他神经网络技术的学习过程。
以上学习笔记,如有侵犯,请立即联系并删除!