类型torch.nn.Parameter
官方解释
Parameters是Variable的子类。Variable的一种。Paramenters和Modules一起使用的时候会有一些特殊的属性,即:当Paramenters赋值给Module的属性的时候,他会自动的被加到Module的参数列表中,也就是会出现在parameters()迭代器中。常被用于模块参数module parameter。- 将
Varibale赋值给Module属性则不会有这样的影响。 这样做的原因是:我们有时候会需要缓存一些临时的状态state, 比如:模型中RNN的最后一个隐状态。如果没有Parameter这个类的话,那么这些临时变量也会注册成为模型变量。
Variable与Parameter的另一个不同之处在于,Parameter不能被volatile(即:无法设置volatile=True)而且默认requires_grad=True。Variable默认requires_grad=False。
参数说明:
-
data (Tensor) – parameter tensor
-
requires_grad (bool, optional) – 默认为True,在BP的过程中会对其求微分。
类内函数Module.parameters()
源码分析
可以通过Module.parameters()获取网络的参数,那这个函数的实现细节我们通过代码进行分析:
def parameters(self):
r"""Returns an iterator over module parameters.
This is typically passed to an optimizer.
Yields:
Parameter: module parameter
Example::
>>> for param in model.parameters():
>>> print(type(param.data), param.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
"""
for name, param in self.named_parameters():
yield param
他主要是引用另一个类内成员函数named_parameters(),实现对所有参数的索引包装,生成迭代器,下面看另一个函数:
def named_parameters(self, memo=None, prefix=''):
r"""Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself
Yields:
(string, Parameter): Tuple containing the name and parameter
Example::
>>> for name, param in self.named_parameters():
>>> if name in ['bias']:
>>> print(param.size())
"""
if memo is None:
memo = set()
#本身模块的参数
for name, p in self._parameters.items():
if p is not None and p not in memo:
memo.add(p)
yield prefix + ('.' if prefix else '') + name, p
for mname, module in self.named_children():
submodule_prefix = prefix +

本文详细解读了PyTorch中Parameter类型及其与Module的关系,包括如何在模块中注册和管理参数,以及Variable与Parameter的区别。重点介绍了parameters()和named_parameters()函数的工作原理及其实现细节。
最低0.47元/天 解锁文章
5091

被折叠的 条评论
为什么被折叠?



