神经网络----网络模型的加载及保存

一、加载网络模型

torch.load()函数:

PyTorch 中用于加载保存的模型或张量的函数

例如:

torch.load(checkpoint_path, map_location=device)

其中, checkpoint_path 是保存模型参数的文件路径,

            map_location=device 用于将模型加载到指定的设备上。如果你在训练时使用了 GPU,并

            且想在 CPU 上进行推断或继续训练,这就很有用。map_location 参数告诉 PyTorch 将模 

            型参数加载到指定的设备上。

这句代码的作用:将路径在checkpoint_path的模型参数文件加载到设备device上面

model.load_state_dict( )函数:

将状态字典加载到模型的方法

例如:

model.load_state_dict(torch.load(checkpoint_path, map_location=device))

这行代码的目的是将从文件加载的预训练模型的状态字典应用到指定的模型 model 中。加载后,model 就包含了预训练模型的参数,可以在之后用于推断或继续训练。

其中,预训练模型的状态字典是由torch.load()函数加载的,model.load_state_dict()函数将状态字典应用到模型model上面。

二、保存网络模型

torch.save() 函数用于将对象保存到文件中,以便之后可以使用 torch.load() 函数加载它

一般用于保存模型、张量、字典等 PyTorch 对象

torch.save( )函数的基本用法:

torch.save(obj, file_path)
  • obj: 要保存的Pytorch对象,可以是模型、张量、字典等
  • file_path: 要保存到的文件路径,可以是相对路径或绝对路径

例如,保存模型参数到文件的代码可能如下所示:

import torch

# 假设 model 是 PyTorch 模型
model = ...

# 假设 file_path 是保存文件的路径
file_path = 'model.pth'

# 使用 torch.save() 保存模型参数到文件
torch.save(model.state_dict(), file_path)

在上述示例中,model.state_dict() 返回模型的参数状态字典,它包含了模型的所有可学习参数。这个字典可以通过 torch.load() 函数加载,用于初始化模型或进行模型的迁移学习等任务。 

例如,

torch.save(model.module.state_dict() if hasattr(model, "module") else model.state_dict(), os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch)))

model.module.state_dict() if hasattr(model, "module") else model.state_dict()

其中,

model.state_dict() 返回模型的当前参数状态字典

有些模型在 多 GPU  或分布式训练中可能使用了  nn.DataParallel  封装,导致模型的顶层包装是 nn.DataParallel 对象。如果是这样,那么需要使用 model.module.state_dict() 来获取实际的模型参数状态字典。这里通过 hasattr(model, "module") 来检查模型是否有 module 属性,如果有,则使用 model.module.state_dict(),否则使用 model.state_dict()

os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch))

其中,

os.path.join() 用于拼接文件路径,将模型保存在指定的目录下

output_dir_epoch 是保存模型的目录

{0}_model.pth'.format(epoch) 生成保存文件的名称,其中 {0} 会被替换为当前的 epoch 数字,确保每个模型文件都有唯一的名称,与训练的时期相关

这行代码的作用是将当前模型的参数状态字典保存到一个以 epoch 号命名的文件中

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值