文章目录
nn.RNNBase
简介
nn.RNNBase
是 PyTorch 中所有循环神经网络
(RNN)类的基础类。这个类提供了一些基本的功能和接口,供具体的 RNN 实现(如 nn.RNN
、nn.LSTM
、nn.GRU
)继承和使用。
参数
nn.RNNBase
的构造函数包含以下几个主要参数:
mode
: 指定 RNN 的类型,'RNN_TANH'
或'RNN_RELU'
。input_size
: 每个时间步的输入特征维度。hidden_size
: 隐藏状态的特征维度。num_layers
: RNN 的层数,默认为 1。bias
: 如果为 False,RNN 层不会使用偏置,默认为 True。batch_first
: 如果为 True,输入和输出的张量格式为 (batch, seq, feature),默认为 False。dropout
: 除最后一层外的每个 RNN 层的 dropout 概率,默认为 0。bidirectional
: 如果为 True,则使用双向 RNN,默认为 False。