ParameterList学习
在构造网络模型过程中,经常定义一些模型参数,诸如self.W_ih
, self.W_hh
, 和 self.b_h
都被转换为 torch.nn.ParameterList
,而其中的每个权重或偏置被包装成 torch.nn.Parameter
。这个实现的方式可以通过以下几点来理解:
1. 为什么要使用 torch.nn.ParameterList
?
torch.nn.ParameterList
是 PyTorch 中的一个特殊容器类,它用于存储一系列的 torch.nn.Parameter
对象。与普通的 Python 列表不同,ParameterList
是 PyTorch 模型的一部分,意味着它内部的 Parameter
对象会被模型自动识别,并参与模型的优化过程。
- 普通列表 vs
ParameterList
:
如果你使用普通的 Python 列表(例如list
)来存储这些参数,PyTorch 不会自动将它们注册为模型的参数,这意味着这些参数不会在训练过程中更新。而使用ParameterList
,则确保这些参数能够被自动管理,且在模型的forward
方法中能够被访问和优化。
示例解释:
# 普通 Python 列表,PyTorch 不会跟踪这些参数
self.W_ih = [torch.randn(hidden_size, input_size) for _ in range(num_layers)]
此代码生成一个普通的 Python 列表,虽然包含了 torch.Tensor
对象,但它们不会被自动注册为模型的参数,也就无法被 optimizer
自动更新。
为了使这些参数在训练过程中能被追踪和更新,你需要将它们转换为 torch.nn.Parameter
对象,并使用 ParameterList
来存储它们:
self.W_ih = torch.nn.ParameterList([torch.nn.Parameter(w) for w in self.W_ih])
这样,self.W_ih
就成为了模型的一部分,并在训练过程中可以自动更新。
2. torch.nn.Parameter
的作用
torch.nn.Parameter
是 PyTorch 中的一种特殊张量(Tensor
),它表明该张量是模型的可学习参数。换句话说,将一个张量包装成 Parameter
,意味着 PyTorch 知道它是需要在训练过程中被优化的。
- 普通张量 vs
Parameter
:
如果你直接使用torch.Tensor
,它是不会被自动优化的;而torch.nn.Parameter
会被自动添加到模型的参数列表中,并参与反向传播的计算。
例如:
self.W_ih = torch.nn.Parameter(torch.randn(hidden_size, input_size))
这段代码会将 W_ih
包装成 Parameter
,使其在训练时被优化器更新。
3. 为什么使用 torch.nn.ParameterList
包装多个参数
在你的代码中,每一层 RNN 的权重 W_ih
, W_hh
和偏置 b_h
都被存储在列表中,因为你有 多层 RNN。每层都有一组权重和偏置参数,所以需要一个列表来存储多层的参数。ParameterList
提供了一种优雅的方式来管理这些多层参数。
示例:
self.W_ih = torch.nn.ParameterList([torch.nn.Parameter(w) for w in self.W_ih])
这行代码的意思是:将每一层的权重(W_ih
)转换为 torch.nn.Parameter
对象,然后使用 ParameterList
存储这些权重。这样,每层的权重参数会被正确注册,并在训练时被更新。
4. 什么时候使用到这些参数
这些权重和偏置参数会在 前向传播(forward pass) 和 反向传播(backward pass) 中使用。它们分别是:
self.W_ih
: 输入到隐藏层的权重矩阵。self.W_hh
: 隐藏层到隐藏层的递归权重矩阵。self.b_h
: 隐藏层的偏置项。
在每个时间步,RNN 会使用这些参数来计算输入和隐藏状态之间的关系,并更新隐藏状态。
在每次训练迭代中,经过前向传播后,会使用反向传播来计算这些参数的梯度。由于它们被包装成 torch.nn.Parameter
,PyTorch 会自动计算梯度,并在优化过程中更新它们。
5. 总结
torch.nn.ParameterList
是一个特殊的容器,用来存储和管理多个Parameter
对象,确保它们能够在模型的训练过程中被自动识别和优化。torch.nn.Parameter
是将张量标记为模型的可学习参数,以便 PyTorch 能够追踪和更新它们。- 将多个层的权重和偏置使用
ParameterList
来存储,可以保证多层 RNN 的每一层参数都能被有效管理和优化。
通过这种设计,你可以灵活地管理每一层 RNN 的参数,并确保它们在训练过程中被正确优化。