模型参数常见读取与保存操作

文章介绍了如何使用torch.save存储和torch.load加载模型参数,包括张量的值、形状、数据类型以及梯度信息。在保存模型时,通常会保存state_dict,而加载时使用load_state_dict方法。还提到了在恢复预训练模型时的处理流程,包括加载checkpoint,解析并应用到模型和优化器上。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

简洁版torch.save和torch.load模型参数

	# save
    torch.save(model.state_dict(), "/home/yingmuzhi/microDL_2_0/a.pth")

	# load
	model = UNet()
    model.load_state_dict(torch.load("/home/yingmuzhi/microDL_2_0/a.pth"))

如何使用torch.save存储梯度信息

默认情况下,torch.save不会保存张量的梯度信息。它只会保存张量的值、形状和数据类型。

如果要保存梯度信息,可以在调用torch.save时将张量的requires_grad属性设置为True。这样,保存的文件中将包含梯度信息。

例如:

import torch

# 创建一个需要梯度的张量
x = torch.randn(3, 4, requires_grad=True)

# 计算一些操作
y = 2 * x + 1
z = y.mean()

# 计算梯度
z.backward()

# 保存张量和梯度信息
torch.save(x, 'tensor.pt')

# 加载张量和梯度信息
loaded_x = torch.load('tensor.pt')
print(loaded_x)

在这个例子中,x张量的梯度信息被保存在文件’tensor.pt’中。加载张量时,梯度信息也将被恢复。

模型参数导入存储模板

模型参数常见读取与保存操作,代码见下:

'''
经常存储的会有

- 模型参数, 包括权重weights和偏置bias, 即model.params.weight和model.params.bias
- optimizer
- epoch迭代次数
- ...

我们的处理流程往往是:
先用torch.load()导入, 读成checkpoint
将checkpoint分步骤解析, 如有的可能需要使用load_state_dict()来读取
将需要存储的东西以一个字典的形式组装成checkpoint, 其中可能有些需要存储的东西要使用state_dict()来存储
使用torch.save()将checkpoint存储

... 表示无实意的pass或者: 
参考`https://zhuanlan.zhihu.com/p/264896206`

'''
import os, torch


SAVE_DICTIONARY: dict = {
    "model": {},
    "optimizer": {},
    "start_epoch": 1,
    "args": [],
}



def main(

):
    ...
    
    resume = '',
    model = None,
    optimizer = None,
    # load pre-trained model
    if os.path.exists(resume):
        checkpoint = torch.load(resume, map_location="cpu") # 先用torch.load()全部导入, 读成checkpoint再后续分解
        model.load_state_dict(checkpoint["model"])          # 有些要存储成torch的字典形式的, 就必须使用.load_state_dict()来将dict解压 和 用.state_dict()来压缩成dict 
        optimizer.load_state_dict(checkpoint["optimizer"])
        # lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        start_epoch = checkpoint["epoch"] + 1
        print("load pre-trained model successfully!")
    else:
        print("load pre-trained model failed.")
    
    ...

    epoch = 1
    args = []
    # save model
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        # "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch, 
        "args": args
    }                                                       # 先组装成字典checkpoint, 再使用torch.save()全部存储
    torch.save(checkpoint, resume)

    ...
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值