RNNBase
The LSTM class inherits from RNNBase, which is located at torch.nn.modules.rnn.RNNBase:
class RNNBase(Module):
    __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
                     'batch_first', 'dropout', 'bidirectional']

    mode: str
    input_size: int
    hidden_size: int
    num_layers: int
    bias: bool
    batch_first: bool
    dropout: float
    bidirectional: bool
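__constants__ is a TorchScript feature: the attributes it lists are treated as constants when the module is compiled with TorchScript. See Stack Overflow for details.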
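To make the role of __constants__ concrete, here is a minimal sketch of a scripted module. The Scale class and its factor attribute are made up for illustration and are not part of torch; the sketch only assumes that torch is installed and that torch.jit.script accepts an nn.Module instance.

```python
import torch
from torch import nn


class Scale(nn.Module):
    # Hypothetical module: 'factor' is declared constant for TorchScript.
    __constants__ = ['factor']

    def __init__(self, factor: int):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TorchScript reads self.factor as a fixed value here.
        return x * self.factor


# Compile the module with TorchScript and run it once.
scripted = torch.jit.script(Scale(3))
out = scripted(torch.ones(2))
```

RNNBase lists its hyperparameters (mode, hidden_size, bidirectional and so on) in the same way, since they are set in the constructor and not meant to change afterwards.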
Constructor
First of all, let's look at the constructor, in other words, __init__.
As you can see, torch treats the bidirectional stacked LSTM as the base RNN implementation (the RNNBase class). This architecture is really generic!
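To see what "bidirectional stacked" means for the parameters, here is a rough sketch in plain Python of the weight shapes such a constructor has to create for an LSTM. The helper lstm_weight_shapes is hypothetical and not part of torch. It assumes the LSTM gate size of 4 * hidden_size, assumes that layers after the first take the concatenated outputs of both directions as input, and ignores biases and projections.

```python
def lstm_weight_shapes(input_size, hidden_size, num_layers=1, bidirectional=False):
    """Hypothetical helper: list (name, shape) pairs for the weights of a stacked LSTM."""
    num_directions = 2 if bidirectional else 1
    gate_size = 4 * hidden_size  # LSTM: four gates stacked along the first dimension

    shapes = []
    for layer in range(num_layers):
        # The first layer sees the raw input; deeper layers see the
        # concatenated outputs of every direction of the previous layer.
        layer_input_size = input_size if layer == 0 else hidden_size * num_directions
        for direction in range(num_directions):
            suffix = "_reverse" if direction == 1 else ""
            shapes.append((f"weight_ih_l{layer}{suffix}", (gate_size, layer_input_size)))
            shapes.append((f"weight_hh_l{layer}{suffix}", (gate_size, hidden_size)))
    return shapes


# Example: a 2-layer bidirectional LSTM with input_size=10 and hidden_size=20.
shapes = lstm_weight_shapes(10, 20, num_layers=2, bidirectional=True)
```

A single unidirectional layer is just the special case num_layers=1, bidirectional=False, which is why one base class can cover all of these configurations.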
if mode == 'LSTM':
    gate_size = 4 * hidden_size
elif mode == 'GRU':
    gate_size =<