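🚀 PyTorch API: RNN
⚽️ Pros:
- Model size is independent of sequence length
- Computation grows linearly with sequence length
⚽️ Cons:
- Time steps must be computed one after another, so processing is serial and relatively slow
- It struggles to retain information from far back in the sequence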
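The first two points can be checked directly. The following is a minimal sketch, not a benchmark: the sizes and the helper name `count_params` are illustrative assumptions. It shows that the parameter count of an `nn.RNN` is fixed by `input_size` and `hidden_size`, and that the same module accepts sequences of any length.

```python
import torch
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Total number of scalar parameters in a module (hypothetical helper)."""
    return sum(p.numel() for p in module.parameters())

rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)

# W_ih: 16*8, W_hh: 16*16, b_ih: 16, b_hh: 16
n_params = count_params(rnn)

# The same module (same parameters) handles short and long sequences alike.
short_out, _ = rnn(torch.randn(4, 10, 8))    # 10 time steps
long_out, _ = rnn(torch.randn(4, 1000, 8))   # 1000 time steps

print(n_params)          # 16*8 + 16*16 + 16 + 16 = 416
print(short_out.shape)   # torch.Size([4, 10, 16])
print(long_out.shape)    # torch.Size([4, 1000, 16])
```

The longer input simply runs the recurrence for more steps, which is where the linear growth in computation comes from.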
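🚀 The official torch.nn.RNN()
⚽️ Each layer computes:
$$h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh})$$
Here $x_t$ is the input at time step $t$, $h_t$ is the hidden state at time step $t$, and $h_{t-1}$ is the hidden state from the previous step (the initial hidden state $h_0$ for the first step).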
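To make the recurrence concrete, here is a small scalar sketch in plain Python. The weights, biases, and input sequence are made-up values chosen only for illustration (`input_size = hidden_size = 1`), and `rnn_step` is a hypothetical helper name.

```python
import math

# Hypothetical scalar parameters, chosen only for illustration.
w_ih, b_ih = 0.5, 0.0
w_hh, b_hh = -1.0, 0.1

def rnn_step(x_t: float, h_prev: float) -> float:
    """One application of h_t = tanh(x_t*W_ih + b_ih + h_{t-1}*W_hh + b_hh)."""
    return math.tanh(x_t * w_ih + b_ih + h_prev * w_hh + b_hh)

xs = [1.0, 0.0, -1.0]   # a toy input sequence
h = 0.0                 # initial hidden state h_0
hidden_states = []
for x_t in xs:
    h = rnn_step(x_t, h)   # each step reuses the previous hidden state
    hidden_states.append(h)

print(hidden_states)
```

Because of the `tanh`, every hidden state stays strictly between -1 and 1, and each one depends on all earlier inputs through `h`.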
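⚽️ Parameters for instantiating torch.nn.RNN():
- input_size – number of features in the input
- hidden_size – number of features in the hidden state
- num_layers – number of stacked RNN layers. Default: 1
- nonlinearity – the non-linear activation function, 'tanh' or 'relu'. Default: 'tanh'
- bias – whether to add the bias terms. Default: True
- batch_first – whether the batch dimension comes first; if True, the input shape should be (batch, seq, feature). Default: False
- dropout – if non-zero, dropout with this probability is applied to the outputs of every layer except the last. Default: 0
- bidirectional – whether to use a bidirectional RNN. Default: False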
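One way to see what these arguments do is to list the parameters they create. The sketch below uses illustrative sizes and inspects `named_parameters()` for a 2-layer bidirectional RNN.

```python
import torch.nn as nn

# Illustrative sizes; any positive integers would do.
rnn = nn.RNN(input_size=3, hidden_size=5, num_layers=2, bidirectional=True)

# Map each learnable parameter name to its shape.
shapes = {name: tuple(p.shape) for name, p in rnn.named_parameters()}
for name, shape in shapes.items():
    print(f"{name:28s} {shape}")
```

Each layer and direction contributes four tensors (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`). The layer-1 input weight has `2 * hidden_size` columns because it consumes the concatenated forward and backward outputs of layer 0.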
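⚽️ Inputs and outputs of torch.nn.RNN():
- Inputs:
  - input: [L, N, Hin] when batch_first=False, or [N, L, Hin] when batch_first=True
  - h_0: the initial hidden state, of shape [D*num_layers, N, Hout]. Defaults to zeros if not provided.
- Outputs:
  - output: [L, N, D*Hout] when batch_first=False, or [N, L, D*Hout] when batch_first=True
  - h_n: the final hidden state, of shape [D*num_layers, N, Hout]
Here L is the sequence length, N is the batch size, Hin is input_size, Hout is hidden_size, and D = 2 for a bidirectional RNN (D = 1 otherwise).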
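The shape rules above can be verified with a short sketch. The sizes are arbitrary illustrative values; the RNN is 2-layer and bidirectional so that both `D` and `num_layers` show up in the shapes.

```python
import torch
import torch.nn as nn

L, N, H_in, H_out = 7, 3, 4, 6   # illustrative sizes
num_layers = 2                   # with bidirectional=True, D = 2

rnn = nn.RNN(H_in, H_out, num_layers=num_layers, bidirectional=True)  # batch_first=False
x = torch.randn(L, N, H_in)
output, h_n = rnn(x)             # h_0 omitted, so it defaults to zeros

print(output.shape)   # torch.Size([7, 3, 12])  -> [L, N, D*H_out]
print(h_n.shape)      # torch.Size([4, 3, 6])   -> [D*num_layers, N, H_out]

# With batch_first=True only input/output swap their first two axes;
# the hidden state keeps the [D*num_layers, N, H_out] layout.
rnn_bf = nn.RNN(H_in, H_out, num_layers=num_layers, bidirectional=True, batch_first=True)
output_bf, h_n_bf = rnn_bf(x.transpose(0, 1))
print(output_bf.shape)  # torch.Size([3, 7, 12])
print(h_n_bf.shape)     # torch.Size([4, 3, 6])
```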
⚽️ Usage example of torch.nn.RNN():
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.RNN(input_size=5, hidden_size=10, batch_first=True)
>>> # inputs has shape [batch, timestep, input_size]
>>> inputs = torch.randn(2, 5, 5)
>>> outputs, h_n = rnn(inputs)
>>> # outputs has shape [batch, timestep, hidden_size]
>>> outputs.shape
torch.Size([2, 5, 10])
>>> h_n.shape
torch.Size([1, 2, 10])
🚀 Hand-writing the forward pass of a single-layer, unidirectional RNN
$$h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh})$$
import torch
import torch.nn as nn


class CustomRNN:
    def __init__(self, rnn_torch: nn.RNN):
        # Reuse the parameters of the instantiated nn.RNN so the results can be compared
        self.w_ih = rnn_torch.weight_ih_l0  # [hidden_size, input_size]
        self.w_hh = rnn_torch.weight_hh_l0  # [hidden_size, hidden_size]
        self.b_ih = rnn_torch.bias_ih_l0    # [hidden_size]
        self.b_hh = rnn_torch.bias_hh_l0    # [hidden_size]

    def __call__(self, inputs, h_prev):
        # inputs: [batch, timestep, input_size]; h_prev: [batch, hidden_size]
        batch, timestep, input_size = inputs.shape
        hidden_size = self.w_hh.shape[0]
        # Allocate the output buffer; every entry is overwritten in the loop
        h_out = inputs.new_zeros(batch, timestep, hidden_size)
        # The weights are shared across time steps, so tile them over the batch once
        batch_w_ih = self.w_ih.unsqueeze(0).tile(batch, 1, 1)  # [batch, hidden_size, input_size]
        batch_w_hh = self.w_hh.unsqueeze(0).tile(batch, 1, 1)  # [batch, hidden_size, hidden_size]
        # Run the recurrence one time step at a time
        for i in range(timestep):
            x = inputs[:, i, :].unsqueeze(-1)  # input features at the current step [batch, input_size, 1]
            w_times_x = torch.bmm(batch_w_ih, x).squeeze(-1)  # [batch, hidden_size]
            w_times_h = torch.bmm(batch_w_hh, h_prev.unsqueeze(-1)).squeeze(-1)  # [batch, hidden_size]
            h_prev = torch.tanh(w_times_x + self.b_ih + w_times_h + self.b_hh)
            h_out[:, i, :] = h_prev
        return h_out, h_prev.unsqueeze(0)  # h_n: [1, batch, hidden_size]
# Set the sizes
batch = 2
input_size = 2
hidden_size = 4
timestep = 4
# Build the input and h_0
inputs = torch.randn(batch, timestep, input_size)
h_0 = torch.zeros(batch, hidden_size)
# PyTorch API
rnn = nn.RNN(input_size, hidden_size, batch_first=True)
out_torch, h_n_torch = rnn(inputs, h_0.unsqueeze(0))
# Hand-written forward pass
custom_rnn = CustomRNN(rnn)
out_custom, h_n_custom = custom_rnn(inputs, h_0)
# Check that the outputs and the final hidden states agree
result = torch.allclose(out_torch, out_custom, atol=1e-6)
print(result)
print(torch.allclose(h_n_torch, h_n_custom, atol=1e-6))
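The same recurrence can also be written without tiling the weights, since a plain matrix product already broadcasts over the batch. The following is a sketch of that alternative, assuming a single-layer, unidirectional, `batch_first=True` RNN with the default tanh activation and biases; `rnn_forward` is a hypothetical helper name, not part of PyTorch.

```python
import torch
import torch.nn as nn

def rnn_forward(rnn: nn.RNN, inputs: torch.Tensor, h_0: torch.Tensor):
    """Forward pass of a 1-layer, unidirectional, batch_first tanh RNN (hypothetical helper)."""
    W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
    b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0
    h = h_0                                   # [batch, hidden_size]
    outputs = []
    for x_t in inputs.unbind(dim=1):          # x_t: [batch, input_size]
        # Direct transcription of h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh)
        h = torch.tanh(x_t @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        outputs.append(h)
    # Stack per-step states along the time axis; add the layer axis to h_n
    return torch.stack(outputs, dim=1), h.unsqueeze(0)

torch.manual_seed(0)
rnn = nn.RNN(input_size=2, hidden_size=4, batch_first=True)
inputs = torch.randn(2, 4, 2)
h_0 = torch.zeros(2, 4)

out_ref, h_n_ref = rnn(inputs, h_0.unsqueeze(0))
out_mm, h_n_mm = rnn_forward(rnn, inputs, h_0)
print(torch.allclose(out_ref, out_mm, atol=1e-5))
print(torch.allclose(h_n_ref, h_n_mm, atol=1e-5))
```

Collecting the per-step states in a list and stacking them at the end avoids in-place writes into a preallocated buffer, which keeps the code closer to the formula.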