Difference between RNN and GRU
The two networks share the same interface; you only need to swap the module name in the network definition
self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
The input data and output interfaces stay exactly the same
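A minimal sketch of the swap described above: the same wrapper module works with either `nn.RNN` or `nn.GRU`, and the output shapes are identical. The `rnn_type` switch and the `SimpleRecurrent` class name are illustrative additions, not part of the original notes.

```python
import torch
import torch.nn as nn

class SimpleRecurrent(nn.Module):
    # rnn_type is a hypothetical convenience switch; the notes simply
    # swap the line by hand between nn.RNN and nn.GRU
    def __init__(self, input_size, hidden_size, num_layers, rnn_type="rnn"):
        super().__init__()
        cls = nn.GRU if rnn_type == "gru" else nn.RNN
        self.rnn = cls(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        # x.shape = (batch_size, time_step, input_size)
        out, hn = self.rnn(x)  # h0 defaults to zeros when omitted
        # out.shape = (batch_size, time_step, hidden_size)
        return out

x = torch.randn(4, 10, 8)  # (batch_size=4, time_step=10, input_size=8)
for kind in ("rnn", "gru"):
    model = SimpleRecurrent(8, 16, 2, rnn_type=kind)
    print(kind, model(x).shape)
```

Both variants print `torch.Size([4, 10, 16])`, confirming that only the module name changes while the calling code is untouched.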
Difference between RNN and LSTM
The RNN and LSTM network definitions are also the same except for the module name, but the LSTM forward pass additionally initializes and outputs a cell state
Network definition
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
Forward pass
x: (batch_size, time_step, input_size)
- RNN
def forward(self, x):
    # Forward pass
    # x.shape  = (batch_size, time_step, input_size)
    # h0.shape = (num_layers, batch_size, hidden_size)
    # out.shape = (batch_size, time_step, hidden_size)
    h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
    out, hn = self.rnn(x, h0)
    return out
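For comparison, a sketch of the LSTM forward pass mentioned above: the only difference from the RNN version is the extra cell state `c0`, passed in (and returned) as a tuple with `h0`. The `SimpleLSTM` wrapper class is an illustrative assumption, not from the original notes.

```python
import torch
import torch.nn as nn

class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        # LSTM needs a cell state c0 alongside the hidden state h0
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        # states go in and come out as a (h, c) tuple, unlike nn.RNN
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # out.shape = (batch_size, time_step, hidden_size)
        return out

x = torch.randn(4, 10, 8)  # (batch_size, time_step, input_size)
model = SimpleLSTM(8, 16, 2)
print(model(x).shape)  # torch.Size([4, 10, 16])
```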