torch.save torch.load 四种使用方式 如何加载模型 如何加载模型参数 如何保存模型 如何保存模型参数

在 PyTorch 中,我们可以使用 torch.save 函数将模型或张量保存到文件中,使用 torch.load 函数从文件中加载模型或张量。具体用法如下:

保存模型

	import torch
    # 定义模型
    model = ...
    # 保存模型
    torch.save(model.state_dict(), 'model.pth')

在上面的代码中,我们使用 model.state_dict() 函数将模型的参数保存为一个字典,并使用 torch.save 函数将字典保存到名为 'model.pth' 的文件中。如果需要保存整个模型,可以使用 torch.save(model, 'model.pth') 函数保存模型。

加载模型

	import torch
    # 定义模型
    model = ...
    # 加载模型
    model.load_state_dict(torch.load('model.pth'))

在上面的代码中,我们使用 torch.load 函数从名为 'model.pth' 的文件中加载模型的参数字典,并使用 model.load_state_dict 函数加载参数字典到模型中。如果需要加载整个模型,可以使用 model = torch.load('model.pth') 函数加载模型。

保存张量

	import torch
    # 定义张量
    tensor = ...
    # 保存张量
    torch.save(tensor, 'tensor.pth')

在上面的代码中,我们使用 torch.save 函数将张量保存到名为 'tensor.pth' 的文件中。

加载张量

	import torch
    # 加载张量
    tensor = torch.load('tensor.pth')

在上面的代码中,我们使用 torch.load 函数从名为 'tensor.pth' 的文件中加载张量。

如果使用 torch.save(model) 函数保存整个模型,可以使用 torch.load 函数直接加载整个模型。具体用法如下:

保存模型

	import torch
    # 定义模型
    model = ...
    # 保存模型
    torch.save(model, 'model.pth')

在上面的代码中,我们使用 torch.save 函数将整个模型保存到名为 'model.pth' 的文件中。

加载模型

	import torch
    # 加载模型
    model = torch.load('model.pth')

在上面的代码中,我们使用 torch.load 函数从名为 'model.pth' 的文件中加载整个模型。需要注意的是,如果模型是在 GPU 上训练的,加载模型时需要使用 map_location 参数将模型映射到 CPU 上:

	import torch
    # 加载模型
    model = torch.load('model.pth', map_location=torch.device('cpu'))

如果模型是在 GPU 上训练的,而且需要将模型加载到指定的 GPU 上,可以使用 torch.cuda.device 函数切换到指定的 GPU,然后将模型加载到该 GPU 上:

	import torch
    # 切换到指定的 GPU
    torch.cuda.device(1)
    # 加载模型
    model = torch.load('model.pth', map_location=torch.device('cuda:1'))

在上面的代码中,我们使用 torch.cuda.device 函数切换到索引为 1 的 GPU,然后将模型加载到该 GPU 上。

如果使用 torch.save(model) 函数保存模型,加载模型时可以使用 model.load_state_dict 函数只加载模型的参数。具体用法如下:

保存模型

	import torch
    # 定义模型
    model = ...
    # 保存模型
    torch.save(model, 'model.pth')

在上面的代码中,我们使用 torch.save 函数将整个模型保存到名为 'model.pth' 的文件中。

加载模型参数

	import torch
    # 定义模型
    model = ...
    # 加载模型参数
    model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))

在上面的代码中,我们使用 torch.load 函数从名为 'model.pth' 的文件中加载整个模型,并使用 model.load_state_dict 函数将加载的参数字典加载到模型中。需要注意的是,如果模型是在 GPU 上训练的,加载模型时需要使用 map_location 参数将模型映射到 CPU 上。如果模型在 GPU 上训练并且需要加载到指定的 GPU 上,请参考前面的回答。

  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UCAS_V

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值