state_dict() -->顾名思义:状态的字典。对了一半,还有一半是所有可学习参数的状态字典
1.定义和作用:
state_dict
是一个从参数名称映射到参数张量的字典对象(dict)。- 它包含了模型的所有可学习参数(如权重和偏置)以及它们的当前值(state)。所以叫state_dict
state_dict
用来保存、加载和转移模型的状态。
2.用法
获取state_dict
:
model = YourModel(...)
state_dict = model.state_dict()
首先要定义自己的模型,毕竟参数源于模型网络,返回一个包含模型所有参数及其值的字典。
保存state_dict
通常,你会将 state_dict
保存到一个文件中,以便之后可以重新加载模型。
torch.save(model.state_dict(), PATH)
加载state_dict
要从文件中加载 state_dict
,首先你需要有一个与保存时相同架构的模型实例。
clone = YourModel(...)
clone.load_state_dict(torch.load(PATH))
clone.eval() # 确保在评估模式下运行
将从文件中加载 state_dict
并将其应用到模型上。
完整示例
定义模型并保存 state_dict
"""加载和保存模型参数"""
# 模型定义
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(20, 256)
self.out = nn.Linear(256,10)
def forward(self, X):
return self.out(F.relu(self.layer1(X)))
net = MLP() # 实例化模型
X = torch.randn(2, 20)
Y = net(X) # 模型前向变换
# print(Y)
# 保存模型参数
torch.save(net.state_dict(), 'mlp.params') # 把MLP所有参数存成一个字典
#加载模型参数
clone = MLP() # 重新声明,实例化模型,相当于一个空壳子,下一步把参数加载进去
clone.load_state_dict(torch.load("mlp.params"))
clone.eval() # 评估模式
print(clone)
# 参数验证(判断前后输出是否相同)
Y2 = clone(X)
print(Y == Y2)
输出:
MLP(
(layer1): Linear(in_features=20, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])