torch.save的用法

介绍

torch.save 是 PyTorch 中用于保存对象(如模型、张量、字典等)的函数。它可以将数据序列化并保存到文件中,方便后续加载和使用。

基本用法

torch.save(obj, f)

参数说明:

  • obj:要保存的对象,可以是模型、张量、字典等。
  • f:保存的目标文件路径,可以是:
    • 文件路径字符串(如 ‘model.pth’)。
    • 一个文件对象(如 open(‘model.pth’, ‘wb’))。
    • 一个 torch.ByteIO 对象(用于保存到内存中)。

常见用法

保存张量

import torch  

# 创建一个张量  
tensor = torch.tensor([1, 2, 3, 4])  

# 保存张量到文件  
torch.save(tensor, 'tensor.pth')  

# 加载张量  
loaded_tensor = torch.load('tensor.pth')  
print(loaded_tensor)  # 输出:tensor([1, 2, 3, 4])

保存模型的参数

保存模型的参数(state_dict)是 PyTorch 推荐的保存模型的方式,因为它只保存模型的权重和偏置,而不保存整个模型结构。

import torch  
import torch.nn as nn  

# 定义一个简单的模型  
model = nn.Linear(10, 1)  

# 保存模型的参数  
torch.save(model.state_dict(), 'model.pth')  

# 加载模型的参数  
model2 = nn.Linear(10, 1)  # 需要重新定义模型结构  
model2.load_state_dict(torch.load('model.pth'))  
print(model2.state_dict())  # 输出模型的参数

保存整个模型(不推荐)

可以直接保存整个模型(包括模型结构和参数),但这种方式依赖于保存时的代码环境,可能在不同版本的 PyTorch 或不同的代码结构中无法加载。

# 保存整个模型  
torch.save(model, 'entire_model.pth')  

# 加载整个模型  
loaded_model = torch.load('entire_model.pth')  
print(loaded_model)

注意事项

  • 推荐保存 state_dict
    • 保存 state_dict(模型参数)比保存整个模型更灵活,因为它不依赖于保存时的代码环境。
    • 加载时需要重新定义模型结构,然后加载参数。
  • 文件扩展名
    • 通常使用 .pth 或 .pt 作为保存文件的扩展名,但这只是约定俗成,PyTorch 并不强制要求。
  • GPU 和 CPU 的兼容性
    • 如果保存的模型是在 GPU 上,但加载时在 CPU 上,需要显式指定 map_location 参数。

将网络模型保存到文件中,这将保存网络的结构和参数

torch.save(net, ‘./data/net.pkl’)

将网络的状态字典保存到文件中,状态字典包含了网络的参数

torch.save(net.state_dict(), ‘./data/net_params.pkl’)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值