一、加载网络模型
torch.load()函数:
PyTorch 中用于加载保存的模型或张量的函数
例如:
torch.load(checkpoint_path, map_location=device)
其中, checkpoint_path
是保存模型参数的文件路径,
map_location=device
用于将模型加载到指定的设备上。如果你在训练时使用了 GPU,并
且想在 CPU 上进行推断或继续训练,这就很有用。map_location
参数告诉 PyTorch 将模
型参数加载到指定的设备上。
这句代码的作用:将路径在checkpoint_path的模型参数文件加载到设备device上面
model.load_state_dict( )函数:
将状态字典加载到模型的方法
例如:
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
这行代码的目的是将从文件加载的预训练模型的状态字典应用到指定的模型 model
中。加载后,model
就包含了预训练模型的参数,可以在之后用于推断或继续训练。
其中,预训练模型的状态字典是由torch.load()函数加载的,model.load_state_dict()函数将状态字典应用到模型model上面。
二、保存网络模型
torch.save()
函数用于将对象保存到文件中,以便之后可以使用 torch.load()
函数加载它
一般用于保存模型、张量、字典等 PyTorch 对象
torch.save( )函数的基本用法:
torch.save(obj, file_path)
obj
: 要保存的Pytorch对象,可以是模型、张量、字典等file_path
: 要保存到的文件路径,可以是相对路径或绝对路径
例如,保存模型参数到文件的代码可能如下所示:
import torch
# 假设 model 是 PyTorch 模型
model = ...
# 假设 file_path 是保存文件的路径
file_path = 'model.pth'
# 使用 torch.save() 保存模型参数到文件
torch.save(model.state_dict(), file_path)
在上述示例中,model.state_dict()
返回模型的参数状态字典,它包含了模型的所有可学习参数。这个字典可以通过 torch.load()
函数加载,用于初始化模型或进行模型的迁移学习等任务。
例如,
torch.save(model.module.state_dict() if hasattr(model, "module") else model.state_dict(), os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch)))
model.module.state_dict() if hasattr(model, "module") else model.state_dict()
:
其中,
model.state_dict()
返回模型的当前参数状态字典
有些模型在 多 GPU 或分布式训练中可能使用了 nn.DataParallel
封装,导致模型的顶层包装是 nn.DataParallel
对象。如果是这样,那么需要使用 model.module.state_dict()
来获取实际的模型参数状态字典。这里通过 hasattr(model, "module")
来检查模型是否有 module
属性,如果有,则使用 model.module.state_dict()
,否则使用 model.state_dict()
os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch))
:
其中,
os.path.join()
用于拼接文件路径,将模型保存在指定的目录下
output_dir_epoch
是保存模型的目录
{0}_model.pth'.format(epoch)
生成保存文件的名称,其中 {0}
会被替换为当前的 epoch
数字,确保每个模型文件都有唯一的名称,与训练的时期相关
这行代码的作用是将当前模型的参数状态字典保存到一个以 epoch 号命名的文件中