保存模型可训练参数的正确姿势:state_dict vs named_parameters

state_dict 和 named_parameters 的区别

首先,让我们搞清楚两个重要的概念:state_dict 和 named_parameters。

state_dict 是什么?

state_dict 是一个字典对象,它保存了模型的所有参数和缓冲区。它是模型持久化的好帮手,可以用来保存和加载模型的权重。

state_dict = model.state_dict()

named_parameters 是什么?

named_parameters 返回的是一个生成器,包含模型中所有参数的名字和参数本身。它是我们检查和操作模型参数的好工具。

for name, param in model.named_parameters():
    print(name, param.shape)

为什么 state_dict 不能直接用来检查 requires_grad?

state_dict 只保存参数的值,而不保存参数的 requires_grad 属性。换句话说,它不知道哪些参数是可训练的,哪些不是。这就像你去超市买东西,购物清单上只有商品的名字和价格,但没有标明哪些是打折商品。

如何正确保存可训练的参数?

为了确保我们只保存那些可训练的参数,我们需要结合 named_parameters 和 state_dict。下面是一个简单的示例,展示了如何正确地保存和打印可训练的参数。

示例代码1

import torch
import torch.nn as nn

# 假设你有一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = SimpleModel()

# 打印所有可训练的参数
print('打印所有可训练的参数')
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Parameter: {name}, Shape: {param.shape}, Total: {param.numel()}")

# 保存可训练的参数
print('保存可训练的参数')
trainable_params = {name: param for name, param in model.named_parameters() if param.requires_grad}
for name, param in trainable_params.items():
    print(name)

logger.info('可训练的参数:')
model.print_trainable_parameters()
for name, param in model.named_parameters():
    if param.requires_grad:
        if 'lora' in name:
            logger.info(f"LoRA parameter: {name}, Shape: {param.shape}, Total: {param.numel()}")
        elif 'classifier' in name:
            logger.info(f"Classifier parameter: {name}, Shape: {param.shape}, Total: {param.numel()}")
        else:
            logger.info(f"Other trainable parameter: {name}, Shape: {param.shape}, Total: {param.numel()}")

print('保存可训练的参数')
for k, v in model.state_dict().items():
    if v.requires_grad:
        print(k)
trainable_params = {k: v for k, v in model.state_dict().items() if v.requires_grad}
torch.save(trainable_params, 'trainable_params.pth')

在这里插入图片描述
在这里插入图片描述

示例代码2

import torch
import torch.nn as nn

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)
        self.bn = nn.BatchNorm1d(20)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# 创建模型实例
model = SimpleModel()

# 打印可训练参数(使用 named_parameters)
print("=======Trainable parameters using named_parameters==========")
for name, param in model.named_parameters() :
    print(f"Parameter name: {name}, requires_grad: {param.requires_grad}")


# 获取所有参数和缓冲区(使用 state_dict)
state_dict = model.state_dict()
# 打印所有参数和缓冲区(使用 state_dict)
print("\n=======All parameters and buffers using state_dict==========")
for name, param in state_dict.items():
    print(f"Parameter name: {name}, requires_grad: {param.requires_grad}")

运行输出:
=Trainable parameters using named_parameters====
Parameter name: fc1.weight, requires_grad: True
Parameter name: fc1.bias, requires_grad: True
Parameter name: fc2.weight, requires_grad: True
Parameter name: fc2.bias, requires_grad: True
Parameter name: bn.weight, requires_grad: True
Parameter name: bn.bias, requires_grad: True

=All parameters and buffers using state_dict====
Parameter name: fc1.weight, requires_grad: False
Parameter name: fc1.bias, requires_grad: False
Parameter name: fc2.weight, requires_grad: False
Parameter name: fc2.bias, requires_grad: False
Parameter name: bn.weight, requires_grad: False
Parameter name: bn.bias, requires_grad: False
Parameter name: bn.running_mean, requires_grad: False
Parameter name: bn.running_var, requires_grad: False
Parameter name: bn.num_batches_tracked, requires_grad: False

解释

  • 定义模型:我们定义了一个简单的两层全连接网络。
  • 打印可训练参数:使用 named_parameters 打印所有 requires_grad 为 True 的参数。
  • 保存可训练参数:创建一个字典,只包含可训练的参数。

结论

通过这个简单的示例,我们可以看到,虽然 state_dict 是保存和加载模型权重的好工具,但它并不包含 requires_grad 信息。因此,我们需要使用 named_parameters 来确保只保存那些可训练的参数。

希望这篇博客能帮你避开这个小坑,让你的模型保存工作更加顺利。

  • 9
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值