保存和加载训练模型--torch.save()和torch.load()

介绍两种加载和调用模型的方法,保存的文件格式均为pth文件。

需要注意的是,在加载自己保存的网络时,文件中需要有类的定义(import导入也可)

# 保存方式有两种,分别对应不同的加载方式,且文件类型均为保存为pth文件

# 方式一:
# 保存和加载自己的模型
import torch
import torch.nn as nn


# 定义一个卷积网络
import torchvision


class My_nn_s(nn.Module):
    def __init__(self):
        super(My_nn_s, self).__init__()
        self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
                                   nn.MaxPool2d(2),
                                   )

    def forward(self, x):
        x = self.model(x)
        return x

# 保存模型
my_nn_s = My_nn_s()
torch.save(my_nn_s, 'saved1.pth')  # 注意,保存的对象是实例化的类

# 加载模型
loadnn = torch.load('saved1.pth')
print(loadnn)

# 需要注意的是,这种加载模式类的定义,即My_nn_s的定义(10-19行代码)需要和加载代码(26-27行)位于同一个文件,
# 当不在同一文件时可以通过import方式引入

# 方式一:
# 保存和加载torchvision中已有模型
model = torchvision.models.alexnet(pretrained=False)
torch.save(model, 'saved_alex.pth')

# 加载模型
model_load = torch.load('saved_alex.pth')
print(model_load)

# 注意,当加载torchvision中已有模型时,不同于保存自己定义的网络模型,文件中无需类的定义

# 方式二(官方推荐的方法):
# 仍以自己定义的模型为例:
# 保存模型
my_nn_s1 = My_nn_s()
torch.save(my_nn_s1.state_dict(), 'saved2.pth')  # 注意,保存的对象是实例化的类的状态字典state_dict

# 加载模型
dict_load = torch.load('saved2.pth')  # 先从文件中取出字典
my_nn_new = My_nn_s()  # 然后实例化
my_nn_s.load_state_dict(dict_load)  # 最后加载字典
print(my_nn_new)


# 从以上可知方法二比较麻烦,但是保存的文件较小。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值