pytorch model.name_parameters()和model.static_dict()

原文链接:https://blog.csdn.net/weixin_41712499/article/details/110198423

在使用pytorch过程中,torch中存在3个功能极其类似的方法,分别是:

model.parameters()、model.named_parameters()和model.state_dict()

下面就具体来说说这三个函数的差异
首先,说说比较接近的model.parameters()和model.named_parameters()。这两者唯一的差别在于,named_parameters()返回的list中,每个元祖打包了2个内容,分别是layer-name和layer-param,而parameters()只有后者。

model.named_parameters()和model.state_dict()间的差别。它们的差异主要体现在3方面:

  • 返回值类型不同
  • 存储的模型参数的种类不同
  • 返回的值的require_grad属性不同

第一个,model.state_dict()是将layer_name : layer_param的键值信息存储为dict形式,而model.named_parameters()则是打包成一个元祖然后再存到list当中;
第二,model.state_dict()存储的是该model中包含的所有layer中的所有参数;而model.named_parameters()则只保存可学习、可被更新的参数,model.buffer()中的参数不包含在model.named_parameters()中
最后,model.state_dict()所存储的模型参数tensor的require_grad属性都是False,而model.named_parameters()的require_grad属性都是True

以上参考:原文链接:https://blog.csdn.net/weixin_41712499/article/details/110198423

比如:

# coding=UTF-8
class net-2(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
		......
    
    def forward(self, des):
        ......


model = net-2()

model.state_dict()

model_dict = model.state_dict()
    for k, v in model_dict.items():
      print(k)



部分输出:
mgp_str.mgp_encoder.position_enc.h_position_encoder
mgp_str.mgp_encoder.position_enc.w_position_encoder
mgp_str.mgp_encoder.position_enc.h_scale.0.weight
mgp_str.mgp_encoder.position_enc.h_scale.0.bias
mgp_str.mgp_encoder.position_enc.h_scale.2.weight
mgp_str.mgp_encoder.position_enc.h_scale.2.bias
mgp_str.mgp_encoder.position_enc.w_scale.0.weight
mgp_str.mgp_encoder.position_enc.w_scale.0.bias
mgp_str.mgp_encoder.position_enc.w_scale.2.weight
mgp_str.mgp_encoder.position_enc.w_scale.2.bias

model.name_parameters()

for k_p, v_p in model.named_parameters():
    print(k_p)




部分输出:
module.mgp_str.mgp_encoder.position_enc.h_scale.0.weight
module.mgp_str.mgp_encoder.position_enc.h_scale.0.bias
module.mgp_str.mgp_encoder.position_enc.h_scale.2.weight
module.mgp_str.mgp_encoder.position_enc.h_scale.2.bias
module.mgp_str.mgp_encoder.position_enc.w_scale.0.weight
module.mgp_str.mgp_encoder.position_enc.w_scale.0.bias
module.mgp_str.mgp_encoder.position_enc.w_scale.2.weight
module.mgp_str.mgp_encoder.position_enc.w_scale.2.bias

model_dict=model.state_dict()  #这里的model是上面例子搭建的模型

model_dict=torch.load('***.pth', map_location='cpu')

是一样的功能,输出的参数都是一样的

如果想加载模型参数,如下:

model.load_state_dict(model_dict, strict=False)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值