Pytorch中 类Parameter的解析,类内成员函数.parameters()的源码分析,参数集合的获取,参数的注册赋值源码分析

本文详细解读了PyTorch中Parameter类型及其与Module的关系,包括如何在模块中注册和管理参数,以及Variable与Parameter的区别。重点介绍了parameters()和named_parameters()函数的工作原理及其实现细节。

类型torch.nn.Parameter

官方解释

  • ParametersVariable的子类。Variable的一种。
  • ParamentersModules一起使用的时候会有一些特殊的属性,即:当Paramenters赋值给Module的属性的时候,他会自动的被加到Module的参数列表中,也就是会出现在parameters()迭代器中。常被用于模块参数module parameter
  • Varibale赋值给Module属性则不会有这样的影响。 这样做的原因是:我们有时候会需要缓存一些临时的状态state, 比如:模型中RNN的最后一个隐状态。如果没有Parameter这个类的话,那么这些临时变量也会注册成为模型变量。

VariableParameter的另一个不同之处在于,Parameter不能被volatile(即:无法设置volatile=True)而且默认requires_grad=TrueVariable默认requires_grad=False

参数说明:

  • data (Tensor) – parameter tensor

  • requires_grad (bool, optional) – 默认为True,在BP的过程中会对其求微分。

类内函数Module.parameters()

源码分析

可以通过Module.parameters()获取网络的参数,那这个函数的实现细节我们通过代码进行分析:

def parameters(self):
    r"""Returns an iterator over module parameters.
        This is typically passed to an optimizer.
        Yields:
            Parameter: module parameter
        Example::
            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
        """
    for name, param in self.named_parameters():
        yield param

他主要是引用另一个类内成员函数named_parameters(),实现对所有参数的索引包装,生成迭代器,下面看另一个函数:

def named_parameters(self, memo=None, prefix=''):
    r"""Returns an iterator over module parameters, yielding both the
        name of the parameter as well as the parameter itself
        Yields:
            (string, Parameter): Tuple containing the name and parameter
        Example::
            >>> for name, param in self.named_parameters():
            >>>    if name in ['bias']:
            >>>        print(param.size())
        """
    if memo is None:
        memo = set()
    #本身模块的参数
    for name, p in self._parameters.items():
        if p is not None and p not in memo:
            memo.add(p)
            yield prefix + ('.' if prefix else '') + name, p
    for mname, module in self.named_children():
        submodule_prefix = prefix + 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值