Usage
```python
import torch
import torch.nn as nn

rnn = nn.RNN(10, 20, 2)          # input_size=10, hidden_size=20, num_layers=2
input = torch.randn(5, 3, 10)    # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch, hidden_size)
output, hn = rnn(input, h0)
```
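If the initial hidden state is omitted, it defaults to zeros (see the inputs table below). A quick check of that, as a minimal sketch with the same module sizes as above:

```python
import torch
import torch.nn as nn

rnn = nn.RNN(10, 20, 2)
x = torch.randn(5, 3, 10)

out_default, _ = rnn(x)                        # h0 omitted -> defaults to zeros
out_zero, _ = rnn(x, torch.zeros(2, 3, 20))    # explicit zero h0
print(torch.equal(out_default, out_zero))      # True
```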
API
CLASS torch.nn.RNN(*args, **kwargs)
$$h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})$$
where:

- $h_t$: the hidden state at time t
- $x_t$: the input at time t
- $h_{t-1}$: the hidden state of the previous layer at time t-1
If nonlinearity is 'relu', ReLU is used in place of tanh.
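To make the recurrence concrete, here is a minimal sketch that replays it by hand using the module's layer-0 parameters (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0 are the names nn.RNN exposes) and checks the result against the module's final hidden state:

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=1)
x = torch.randn(5, 3, 10)            # (seq_len, batch, input_size)
h0 = torch.zeros(1, 3, 20)           # (num_layers, batch, hidden_size)
output, hn = rnn(x, h0)

# Replay h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh) step by step.
h_t = h0[0]
for t in range(x.size(0)):
    h_t = torch.tanh(x[t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                     + h_t @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)

print(torch.allclose(hn[0], h_t, atol=1e-6))   # True
```

Note that output stacks every intermediate h_t, while hn keeps only the last one per layer.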
Class
Parameter | Description
---|---
input_size | The number of expected features in the input x
hidden_size | The number of features in the hidden state h
num_layers | Number of recurrent layers. Default: 1
nonlinearity | The non-linearity to use. Can be 'tanh' or 'relu'. Default: 'tanh'
bias | If False, the layer does not use bias weights b_ih and b_hh. Default: True
batch_first | If True, the input and output tensors are provided as (batch, seq, feature). Default: False
bidirectional | If True, becomes a bidirectional RNN. Default: False
- input_size: the dimensionality of each element of the input sequence, not the length of the sentence or sequence. For example, in NLP, if each word fed to the RNN is encoded as a 300-dimensional vector, then input_size is 300.
- hidden_size: each RNN cell is essentially a small feed-forward (BP) network with an input layer, a hidden layer, and an output layer; hidden_size is the number of units in that hidden layer.
- num_layers: the number of stacked recurrent layers; num_layers=2 means two RNNs are stacked on top of each other.
Reference: https://www.cnblogs.com/dhName/p/11760610.html
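The shape-related flags are easiest to see in action. A small sketch (the sizes are arbitrary) of how batch_first and bidirectional change the tensor shapes; note that h_n keeps its (num_layers * num_directions, batch, hidden_size) layout even with batch_first=True:

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=300, hidden_size=128, num_layers=2,
             batch_first=True, bidirectional=True)
x = torch.randn(4, 7, 300)     # (batch, seq, feature) because batch_first=True
output, h_n = rnn(x)           # h_0 omitted -> zeros

print(output.shape)  # torch.Size([4, 7, 256]) -> num_directions * hidden_size = 2 * 128
print(h_n.shape)     # torch.Size([4, 4, 128]) -> num_layers * num_directions = 2 * 2
```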
Instance
Inputs:
Parameter | Description
---|---
input of shape (seq_len, batch, input_size) | Tensor containing the features of the input sequence. The input can also be a packed variable-length sequence; see torch.nn.utils.rnn.pack_padded_sequence() or torch.nn.utils.rnn.pack_sequence() for details.
h_0 of shape (num_layers * num_directions, batch, hidden_size) | Tensor containing the initial hidden state for each element in the batch. Defaults to zero if not provided. If the RNN is bidirectional, num_directions should be 2, else it should be 1.
Outputs:
Parameter | Description
---|---
output of shape (seq_len, batch, num_directions * hidden_size) | Tensor containing the output features (h_t) from the last layer of the RNN, for each t. If a packed sequence was given as the input, the output will also be a packed sequence.
h_n of shape (num_layers * num_directions, batch, hidden_size) | Tensor containing the hidden state for t = seq_len.
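Since the input can be a packed variable-length sequence, here is a hedged sketch of the usual pad-then-pack workflow (the lengths below are arbitrary and sorted descending, as pack_padded_sequence requires by default):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = nn.RNN(input_size=10, hidden_size=20)
padded = torch.randn(5, 3, 10)                  # (max_seq_len, batch, input_size)
lengths = torch.tensor([5, 3, 2])               # true length of each sequence in the batch

packed = pack_padded_sequence(padded, lengths)  # the RNN skips the padded steps
packed_out, h_n = rnn(packed)                   # output comes back packed too
output, out_lengths = pad_packed_sequence(packed_out)

print(output.shape)   # torch.Size([5, 3, 20])
print(out_lengths)    # tensor([5, 3, 2])
```

Only the sequences within one batch are padded to a common length; different batches may have different maximum lengths.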
References:
https://www.icode9.com/content-4-622959.html
https://zhuanlan.zhihu.com/p/59772104
Is it only that sequence lengths within a batch must match, while different batches need not be the same?
https://zhuanlan.zhihu.com/p/97378498
https://www.cnblogs.com/lindaxin/p/8052043.html
https://www.jianshu.com/p/f5b816750839
https://www.jianshu.com/p/efe045c24a93
https://zhuanlan.zhihu.com/p/161972223
https://zhuanlan.zhihu.com/p/34418001?edition=yidianzixun&utm_source=yidianzixun&yidian_docid=0IVwLf60
https://www.cnblogs.com/jiangkejie/p/13141664.html
https://zhuanlan.zhihu.com/p/64527432