pytorch保存与加载神经网络方法

应用pytorch训练一个神经网络后,如何保存神经网络呢?
pytorch中有两种方法,第一种是将网络整体保存,第二种是保存神经网络的参数(推荐第二种!),这里就简单讲讲如何保存参数。

  • 保存
    训练完后加上如下代码
    torch.save(model.state_dict(),“model_params.pkl”)
    这样在运行结束后你会发现你此项目的文件夹里会多出来一个model_params.pkl文件,
    代码中的model是我代码里定义的神经网络的名字,如果你的网络名字叫net,那就写torch.save(net.state_dict(),“model_params.pkl”)。
  • 加载
    加载时运用下面一行代码
    model.load_state_dict(torch.load(‘model_params.pkl’))
    我是在另一个python file中加载的,所以前面还要加上你的神经网络是怎么定义的,所以我的加载代码是这样的:
class Net(nn.Module):
   def __init__(self):
       super(Net, self).__init__()
       self.conv1 = nn.Conv2d(3, 10, 5, 1, 2)
       self.pool = nn.MaxPool2d(2, 2)
       self.conv2 = nn.Conv2d(10, 20, 5, 1, 2)
       self.fc1 = nn.Linear(20*56*56, 120)
       self.fc2 = nn.Linear(120, 84)
       self.fc3 = nn.Linear(84, 4)

   def forward(self, x):
       x = self.pool(F.relu(self.conv1(x)))
       x = self.pool(F.relu(self.conv2(x)))
       x = x.view(-1, 20*56*56)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = self.fc3(x)
       return x
       
model = Net()

model.load_state_dict(torch.load('model_params.pkl'))

补充一下保存整体网络的方法
保存:
torch.save(model, ‘model_params.pkl’)

加载:
由于之前保存的是网络的整体结构,所以在加载的程序中不需要class Net(nn.Module):{……}这一项,这里与保存参书的方法不同。
只需要一行代码
model = torch.load(‘model_params.pkl’)

以上是全部内容,希望对你有帮助哦!

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值