测试模型和保存模型

测试模型和保存模型

训练后的模型进行保存

使用torch.save()函数

torch.save(state,filename)

state是一个字典(dict)对象,它用于保存模型的状态和参数。它包含以下四个键值对:

  • ‘epoch’:当前的轮数(epoch)
  • ‘model’:模型的参数(state_dict).
  • 'optimizer:优化器的参数(state_dict).
  • ‘accuracy’:当前的准确率(accuracy)

这些信息可以用于恢复模型的训练状态,或者评估模型的性能。

file name后后缀一般是".pth"或者".pt"

使用训练好的模型

使用torch.load()函数

torch.load()是一个用来从文件中加载保存的对象的函数,它使用python的pickle模块来进行反序列化。

torch.load()的参数如下:

  • f: 一个类似文件的对象(必须实现read(), readline(), tell(), 和 seek()方法),或者一个包含文件名的字符串或os.PathLike对象。
  • map_location: 一个函数,torch.device对象,字符串或者字典,用来指定如何重新映射存储位置。
  • pickle_module: 用来进行反序列化的模块(必须和序列化时使用的pickle_module相匹配)
  • weights_only: 指示反序列化器是否只加载张量、原始类型和字典。
  • pickle_load_args: (仅限Python 3)传递给pickle_module.load()和pickle_module.Unpickler()的可选关键字参数,比如errors=…。

torch.load()的返回值是任意类型的对象,取决于保存时的对象。

torch.load()通常用来加载保存的模型或优化器的状态字典(state_dict),这些状态字典是使用torch.save()函数保存的。状态字典是一个Python字典对象,它将每一层映射到其参数张量。你可以使用model.load_state_dict()或optimizer.load_state_dict()方法来加载状态字典,并恢复模型或优化器的状态。

torch.load(f,map_location,pickle_module,weights_only,pickle_load_args)
# torch.load()的使用示例
# 假设我们已经保存了一个模型的状态字典到"model.pth"文件中
# 加载状态字典
state_dict = torch.load("model.pth")
# 创建一个相同结构的模型对象
model = TheModelClass(*args, **kwargs)
# 加载状态字典到模型中
model.load_state_dict(state_dict)
# 如果需要,可以将模型移动到指定设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值