【torch】parameters与named_parameters的区别

【torch】parameters与named_parameters的区别

前言

为了详细的查看网络的结构参数等,因此本文研究一下 parameters()named_parameters 的区别。

此示例属于从 nn.Module 中继承的成员函数。函数位于:[python环境路径]/lib/python3.8/site-packages/torch/nn/modules/module.py 文件中。

在这里插入图片描述

不要把时间浪费在无法控制的事情上。生命中最幸福的时刻之一是当你有勇气放下你无法改变的事情。

简要解释说明

从名称上看,named_parametersparameters 多了个 named ,已经能够显示出本质区别来了。

从返回的对象类型来看,parameters 返回的参数类型,named_parameters 返回的是元组,长度为2。其中第一维度为名称, 第二维度为参数类型。

parameters()named_parameters() 都是 PyTorch 中用于获取模型参数的函数,但它们有一些区别:

  • parameters() 返回一个生成器,用于迭代模型中的所有参数,每个参数都是一个 torch.nn.Parameter 类型对象;
  • named_parameters() 也返回一个生成器,用于迭代模型中的所有参数,但是每个参数都是一个元组,包含参数名称和 torch.nn.Parameter 类型对象。

完整代码示例

为了方便阐述,现提供完整的代码示例以增加读者对本质差异的深刻理解。该示例代码已在本地环境下进行了测试验证,如有任何问题,请在评论区留言并提供错误说明,以供改进。

import torch.nn as nn
import torch.nn.init as init

import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm


class BaseClassifier(nn.Module):

    def fresh_params(self, bn_wd):
        if bn_wd:
            return self.parameters()
        else:
            return self.named_parameters()
        
        
class MyClassifier(BaseClassifier):

    def __init__(self):
        super(MyClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.bn1 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(4, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x


model = MyClassifier()

# 返回模型的所有参数
params = model.fresh_params(True)

# 返回模型的所有命名 和 参数
named_params = model.fresh_params(False)

print(f"params: {params}, named_params: {named_params}")

for n in params:
    print(f"type(n): {type(n)}, n.shape: {n.shape}")
    
for n in named_params:
    print(f"type(n): {type(n)}, len(n): {len(n)}, n[0]: {n[0]}, n[1].shape: {n[1].shape}")

输出:

params: <generator object Module.parameters at 0x7f4f93399e40>, named_params: <generator object Module.named_parameters at 0x7f4f93399eb0>
type(n): <class 'torch.nn.parameter.Parameter'>, n.shape: torch.Size([16, 3, 3, 3])
type(n): <class 'torch.nn.parameter.Parameter'>, n.shape: torch.Size([16])
type(n): <class 'torch.nn.parameter.Parameter'>, n.shape: torch.Size([16])
type(n): <class 'torch.nn.parameter.Parameter'>, n.shape: torch.Size([16])
type(n): <class 'torch.nn.parameter.Parameter'>, n.shape: torch.Size([2, 4])
type(n): <class 'torch.nn.parameter.Parameter'>, n.shape: torch.Size([2])
type(n): <class 'tuple'>, len(n): 2, n[0]: conv1.weight, n[1].shape: torch.Size([16, 3, 3, 3])
type(n): <class 'tuple'>, len(n): 2, n[0]: conv1.bias, n[1].shape: torch.Size([16])
type(n): <class 'tuple'>, len(n): 2, n[0]: bn1.weight, n[1].shape: torch.Size([16])
type(n): <class 'tuple'>, len(n): 2, n[0]: bn1.bias, n[1].shape: torch.Size([16])
type(n): <class 'tuple'>, len(n): 2, n[0]: fc1.weight, n[1].shape: torch.Size([2, 4])
type(n): <class 'tuple'>, len(n): 2, n[0]: fc1.bias, n[1].shape: torch.Size([2])

截图:
在这里插入图片描述
可以看到,named_parameters() 函数返回的不仅是参数本身,还包含了参数的名称。若需要获取参数的名称,建议使用 named_parameters() 函数。否则,可以使用 parameters() 函数。

参考文献

示例中的代码部分参考论文:
[1] JIA J, HUANG H, CHEN X, 等. Rethinking of Pedestrian Attribute Recognition: A Reliable Evaluation under Zero-Shot Pedestrian Identity Setting[M/OL]. arXiv, 2021[2023-06-30]. http://arxiv.org/abs/2107.03576. DOI:10.48550/arXiv.2107.03576.
代码位于:
@misc{BibEntry2023Oct,
title = {{Rethinking{ _ \_ _}of{ _ \_ _}PAR}},
year = {2023},
month = oct,
urldate = {2023-10-05},
language = {english},
note = {[Online; accessed 5. Oct. 2023]},
url = {https://github.com/valencebond/Rethinking_of_PAR}
}

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

或许,这就是梦想吧!

如果对你有用,欢迎打赏。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值