测试模型和保存模型
训练后的模型进行保存
使用torch.save()函数
torch.save(state,filename)
state是一个字典(dict)对象,它用于保存模型的状态和参数。它包含以下四个键值对:
- ‘epoch’:当前的轮数(epoch)
- ‘model’:模型的参数(state_dict).
- 'optimizer:优化器的参数(state_dict).
- ‘accuracy’:当前的准确率(accuracy)
这些信息可以用于恢复模型的训练状态,或者评估模型的性能。
file name后后缀一般是".pth"或者".pt"
使用训练好的模型
使用torch.load()函数
torch.load()是一个用来从文件中加载保存的对象的函数,它使用python的pickle模块来进行反序列化。
torch.load()的参数如下:
- f: 一个类似文件的对象(必须实现read(), readline(), tell(), 和 seek()方法),或者一个包含文件名的字符串或os.PathLike对象。
- map_location: 一个函数,torch.device对象,字符串或者字典,用来指定如何重新映射存储位置。
- pickle_module: 用来进行反序列化的模块(必须和序列化时使用的pickle_module相匹配)
- weights_only: 指示反序列化器是否只加载张量、原始类型和字典。
- pickle_load_args: (仅限Python 3)传递给pickle_module.load()和pickle_module.Unpickler()的可选关键字参数,比如errors=…。
torch.load()的返回值是任意类型的对象,取决于保存时的对象。
torch.load()通常用来加载保存的模型或优化器的状态字典(state_dict),这些状态字典是使用torch.save()函数保存的。状态字典是一个Python字典对象,它将每一层映射到其参数张量。你可以使用model.load_state_dict()或optimizer.load_state_dict()方法来加载状态字典,并恢复模型或优化器的状态。
torch.load(f,map_location,pickle_module,weights_only,pickle_load_args)
# torch.load()的使用示例
# 假设我们已经保存了一个模型的状态字典到"model.pth"文件中
# 加载状态字典
state_dict = torch.load("model.pth")
# 创建一个相同结构的模型对象
model = TheModelClass(*args, **kwargs)
# 加载状态字典到模型中
model.load_state_dict(state_dict)
# 如果需要,可以将模型移动到指定设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)