测试模型和保存模型

本文介绍了在PyTorch中如何使用torch.save()函数保存模型和训练状态,包括模型参数、优化器状态及精度等信息。保存的文件通常以.pth或.pt为扩展名。之后,利用torch.load()可以加载这些保存的对象,通过model.load_state_dict()恢复模型的状态。此外,还讨论了如何根据设备选择加载模型的位置。
摘要由CSDN通过智能技术生成

测试模型和保存模型

训练后的模型进行保存

使用torch.save()函数

torch.save(state,filename)

state是一个字典(dict)对象,它用于保存模型的状态和参数。它包含以下四个键值对:

  • ‘epoch’:当前的轮数(epoch)
  • ‘model’:模型的参数(state_dict).
  • 'optimizer:优化器的参数(state_dict).
  • ‘accuracy’:当前的准确率(accuracy)

这些信息可以用于恢复模型的训练状态,或者评估模型的性能。

file name后后缀一般是".pth"或者".pt"

使用训练好的模型

使用torch.load()函数

torch.load()是一个用来从文件中加载保存的对象的函数,它使用python的pickle模块来进行反序列化。

torch.load()的参数如下:

  • f: 一个类似文件的对象(必须实现read(), readline(), tell(), 和 seek()方法),或者一个包含文件名的字符串或os.PathLike对象。
  • map_location: 一个函数,torch.device对象,字符串或者字典,用来指定如何重新映射存储位置。
  • pickle_module: 用来进行反序列化的模块(必须和序列化时使用的pickle_module相匹配)
  • weights_only: 指示反序列化器是否只加载张量、原始类型和字典。
  • pickle_load_args: (仅限Python 3)传递给pickle_module.load()和pickle_module.Unpickler()的可选关键字参数,比如errors=…。

torch.load()的返回值是任意类型的对象,取决于保存时的对象。

torch.load()通常用来加载保存的模型或优化器的状态字典(state_dict),这些状态字典是使用torch.save()函数保存的。状态字典是一个Python字典对象,它将每一层映射到其参数张量。你可以使用model.load_state_dict()或optimizer.load_state_dict()方法来加载状态字典,并恢复模型或优化器的状态。

torch.load(f,map_location,pickle_module,weights_only,pickle_load_args)
# torch.load()的使用示例
# 假设我们已经保存了一个模型的状态字典到"model.pth"文件中
# 加载状态字典
state_dict = torch.load("model.pth")
# 创建一个相同结构的模型对象
model = TheModelClass(*args, **kwargs)
# 加载状态字典到模型中
model.load_state_dict(state_dict)
# 如果需要,可以将模型移动到指定设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值