state_dict 和 named_parameters 的区别
首先,让我们搞清楚两个重要的概念:state_dict 和 named_parameters。
state_dict 是什么?
state_dict 是一个字典对象,它保存了模型的所有参数和缓冲区。它是模型持久化的好帮手,可以用来保存和加载模型的权重。
state_dict = model.state_dict()
named_parameters 是什么?
named_parameters 返回的是一个生成器,包含模型中所有参数的名字和参数本身。它是我们检查和操作模型参数的好工具。
for name, param in model.named_parameters():
print(name, param.shape)
为什么 state_dict 不能直接用来检查 requires_grad?
state_dict 只保存参数的值,而不保存参数的 requires_grad 属性。换句话说,它不知道哪些参数是可训练的,哪些不是。这就像你去超市买东西,购物清单上只有商品的名字和价格,但没有标明哪些是打折商品。
如何正确保存可训练的参数?
为了确保我们只保存那些可训练的参数,我们需要结合 named_parameters 和 state_dict。下面是一个简单的示例,展示了如何正确地保存和打印可训练的参数。
示例代码1
import torch
import torch.nn as nn
# 假设你有一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
model = SimpleModel()
# 打印所有可训练的参数
print('打印所有可训练的参数')
for name, param in model.named_parameters():
if param.requires_grad:
print(f"Parameter: {name}, Shape: {param.shape}, Total: {param.numel()}")
# 保存可训练的参数
print('保存可训练的参数')
trainable_params = {name: param for name, param in model.named_parameters() if param.requires_grad}
for name, param in trainable_params.items():
print(name)
logger.info('可训练的参数:')
model.print_trainable_parameters()
for name, param in model.named_parameters():
if param.requires_grad:
if 'lora' in name:
logger.info(f"LoRA parameter: {name}, Shape: {param.shape}, Total: {param.numel()}")
elif 'classifier' in name:
logger.info(f"Classifier parameter: {name}, Shape: {param.shape}, Total: {param.numel()}")
else:
logger.info(f"Other trainable parameter: {name}, Shape: {param.shape}, Total: {param.numel()}")
print('保存可训练的参数')
for k, v in model.state_dict().items():
if v.requires_grad:
print(k)
trainable_params = {k: v for k, v in model.state_dict().items() if v.requires_grad}
torch.save(trainable_params, 'trainable_params.pth')
示例代码2
import torch
import torch.nn as nn
# 定义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 1)
self.bn = nn.BatchNorm1d(20)
def forward(self, x):
x = self.fc1(x)
x = self.bn(x)
x = torch.relu(x)
x = self.fc2(x)
return x
# 创建模型实例
model = SimpleModel()
# 打印可训练参数(使用 named_parameters)
print("=======Trainable parameters using named_parameters==========")
for name, param in model.named_parameters() :
print(f"Parameter name: {name}, requires_grad: {param.requires_grad}")
# 获取所有参数和缓冲区(使用 state_dict)
state_dict = model.state_dict()
# 打印所有参数和缓冲区(使用 state_dict)
print("\n=======All parameters and buffers using state_dict==========")
for name, param in state_dict.items():
print(f"Parameter name: {name}, requires_grad: {param.requires_grad}")
运行输出:
=Trainable parameters using named_parameters====
Parameter name: fc1.weight, requires_grad: True
Parameter name: fc1.bias, requires_grad: True
Parameter name: fc2.weight, requires_grad: True
Parameter name: fc2.bias, requires_grad: True
Parameter name: bn.weight, requires_grad: True
Parameter name: bn.bias, requires_grad: True
=All parameters and buffers using state_dict====
Parameter name: fc1.weight, requires_grad: False
Parameter name: fc1.bias, requires_grad: False
Parameter name: fc2.weight, requires_grad: False
Parameter name: fc2.bias, requires_grad: False
Parameter name: bn.weight, requires_grad: False
Parameter name: bn.bias, requires_grad: False
Parameter name: bn.running_mean, requires_grad: False
Parameter name: bn.running_var, requires_grad: False
Parameter name: bn.num_batches_tracked, requires_grad: False
解释
- 定义模型:我们定义了一个简单的两层全连接网络。
- 打印可训练参数:使用 named_parameters 打印所有 requires_grad 为 True 的参数。
- 保存可训练参数:创建一个字典,只包含可训练的参数。
结论
通过这个简单的示例,我们可以看到,虽然 state_dict 是保存和加载模型权重的好工具,但它并不包含 requires_grad 信息。因此,我们需要使用 named_parameters 来确保只保存那些可训练的参数。
希望这篇博客能帮你避开这个小坑,让你的模型保存工作更加顺利。