Pytorch自己的网络模型的保存和加载

1 自己写的网络模型的保存(两种方式)

1.1 第一种方式(保存整个网络结构+网络模型参数)

 torch.save(net, 'net.pth')

1.2 第二种方式(只保存网络模型参数)

这种方式是官方推荐的,因为它占的内存比第一种方式小,但是也不会小很多。但是我不推荐使用,因为使用起来比第一种要麻烦很多。

torch.save(net.state_dict(), 'net_params.pth')

2 自己写的网络模型的加载

首选需要说明的一点是,不管上述的那种方式,在我们加载网络模型的时候都需要有预设的网络结构,例如下边代码,否则会提示找不到相应的module

2.1 第一种加载方式

model = torch.load('net.pth') 

实例代码

#加载整个网络
 
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.fc1 = nn.Linear(1024, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)
 
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x
 
net = torch.load("TestSave.pkl")#加载整个模型时直接用这句就可以实例化网络,并且把CUDA上运行这个属性也继承了过来
net.eval() #加上这句后效果更好

2.2 第二种加载方式

model_object.load_state_dict(torch.load('net_params.pth')) 

实例代码

#只加载网络参数
 
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.fc1 = nn.Linear(1024, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)
 
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x
 
net = AlexNet()#只加载网络参数的时候需要自行实例化网络
net.cuda()#并设置网络运行在cpu还是gpu上
 
net.load_state_dict(torch.load('net_params.pth'))#再加载网络的参数
net.eval()

3 注意:

1.只加载网络参数的速度比加载整个网络快得多
2.pth、pkl格式效果相同,ckpt是tensorflow的格式

参考文章

保存加载模型的两种方式

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云雨、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值