原文链接:https://blog.csdn.net/weixin_41712499/article/details/110198423
在使用pytorch过程中,torch中存在3个功能极其类似的方法,分别是:
model.parameters()、model.named_parameters()和model.state_dict()
下面就具体来说说这三个函数的差异
首先,说说比较接近的model.parameters()和model.named_parameters()。这两者唯一的差别在于,named_parameters()返回的list中,每个元祖打包了2个内容,分别是layer-name和layer-param,而parameters()只有后者。
model.named_parameters()和model.state_dict()间的差别。它们的差异主要体现在3方面:
- 返回值类型不同
- 存储的模型参数的种类不同
- 返回的值的require_grad属性不同
第一个,model.state_dict()是将layer_name : layer_param的键值信息存储为dict形式,而model.named_parameters()则是打包成一个元祖然后再存到list当中;
第二,model.state_dict()存储的是该model中包含的所有layer中的所有参数;而model.named_parameters()则只保存可学习、可被更新的参数,model.buffer()中的参数不包含在model.named_parameters()中
最后,model.state_dict()所存储的模型参数tensor的require_grad属性都是False,而model.named_parameters()的require_grad属性都是True
以上参考:原文链接:https://blog.csdn.net/weixin_41712499/article/details/110198423
比如:
# coding=UTF-8
class net-2(nn.Module):
def __init__(self, encoder, decoder, device):
super().__init__()
......
def forward(self, des):
......
model = net-2()
model.state_dict()
model_dict = model.state_dict()
for k, v in model_dict.items():
print(k)
部分输出:
mgp_str.mgp_encoder.position_enc.h_position_encoder
mgp_str.mgp_encoder.position_enc.w_position_encoder
mgp_str.mgp_encoder.position_enc.h_scale.0.weight
mgp_str.mgp_encoder.position_enc.h_scale.0.bias
mgp_str.mgp_encoder.position_enc.h_scale.2.weight
mgp_str.mgp_encoder.position_enc.h_scale.2.bias
mgp_str.mgp_encoder.position_enc.w_scale.0.weight
mgp_str.mgp_encoder.position_enc.w_scale.0.bias
mgp_str.mgp_encoder.position_enc.w_scale.2.weight
mgp_str.mgp_encoder.position_enc.w_scale.2.bias
model.name_parameters()
for k_p, v_p in model.named_parameters():
print(k_p)
部分输出:
module.mgp_str.mgp_encoder.position_enc.h_scale.0.weight
module.mgp_str.mgp_encoder.position_enc.h_scale.0.bias
module.mgp_str.mgp_encoder.position_enc.h_scale.2.weight
module.mgp_str.mgp_encoder.position_enc.h_scale.2.bias
module.mgp_str.mgp_encoder.position_enc.w_scale.0.weight
module.mgp_str.mgp_encoder.position_enc.w_scale.0.bias
module.mgp_str.mgp_encoder.position_enc.w_scale.2.weight
module.mgp_str.mgp_encoder.position_enc.w_scale.2.bias
model_dict=model.state_dict() #这里的model是上面例子搭建的模型
与
model_dict=torch.load('***.pth', map_location='cpu')
是一样的功能,输出的参数都是一样的
如果想加载模型参数,如下:
model.load_state_dict(model_dict, strict=False)