Inputs:
input: [batch, input_size]
hidden: [batch, hidden_size] (optional; defaults to zeros if omitted)
Output:
h′: [batch, hidden_size]
Parameters:
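GRUCell.weight_ih: [3 x hidden_size, input_size]   (W_ir | W_iz | W_in stacked)
GRUCell.weight_hh: [3 x hidden_size, hidden_size]  (W_hr | W_hz | W_hn stacked)
GRUCell.bias_ih: [3 x hidden_size]                 (b_ir | b_iz | b_in)
GRUCell.bias_hh: [3 x hidden_size]                 (b_hr | b_hz | b_hn)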
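The following sketch shows what the stacked parameters contain. It recomputes one GRUCell step by hand from weight_ih, weight_hh, bias_ih and bias_hh. It assumes the gate order (r, z, n) and the update rule from the PyTorch documentation. `gru_cell_manual` is a hypothetical helper written only for this comparison and is not part of torch.nn.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, batch = 5, 10, 2
cell = nn.GRUCell(input_size, hidden_size)

# The three gates (r, z, n) are stacked along dim 0 of each parameter.
assert cell.weight_ih.shape == (3 * hidden_size, input_size)
assert cell.weight_hh.shape == (3 * hidden_size, hidden_size)
assert cell.bias_ih.shape == (3 * hidden_size,)
assert cell.bias_hh.shape == (3 * hidden_size,)

def gru_cell_manual(x, h, cell):
    """Hypothetical helper: recompute one GRUCell step from its weights."""
    gi = x @ cell.weight_ih.T + cell.bias_ih   # [batch, 3*hidden_size]
    gh = h @ cell.weight_hh.T + cell.bias_hh   # [batch, 3*hidden_size]
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)               # reset gate
    z = torch.sigmoid(i_z + h_z)               # update gate
    n = torch.tanh(i_n + r * h_n)              # candidate state (r scales W_hn h + b_hn)
    return (1 - z) * n + z * h                 # h': [batch, hidden_size]

x = torch.randn(batch, input_size)
h = torch.randn(batch, hidden_size)
with torch.no_grad():
    expected = cell(x, h)
    manual = gru_cell_manual(x, h, cell)
print(torch.allclose(expected, manual, atol=1e-5))  # True
```

Only the last line of the update mixes the old state back in. That is why h′ keeps the same [batch, hidden_size] shape as hidden.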
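import torch
from torch.nn import GRUCell

gru_cell = GRUCell(5, 10)    # input_size=5, hidden_size=10
input_ = torch.randn(2, 5)   # [batch=2, input_size=5]
h_0 = torch.randn(2, 10)     # [batch=2, hidden_size=10]
h1 = gru_cell(input_, h_0)   # one time step: h(t) from x(t) and h(t-1)
'''
forward inputs:
input: [batch_size, input_dim]
h(t-1): [batch_size, hidden_size]
forward output:
h(t): [batch_size, hidden_size]
'''
print(h1.shape)
# torch.Size([2, 10])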
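GRUCell handles a single time step. To process a sequence, you feed each h(t) back in as h(t-1) inside a Python loop. The sketch below makes a few assumptions. It uses a single-layer nn.GRU with the default seq-first layout (batch_first=False). It copies the cell's weights into that GRU so the loop's outputs can be checked against it. The sizes (seq_len=4 and so on) are arbitrary illustration values.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 4, 2, 5, 10
cell = nn.GRUCell(input_size, hidden_size)

# Single-layer nn.GRU given the same weights, so the results can be compared.
gru = nn.GRU(input_size, hidden_size, num_layers=1)
with torch.no_grad():
    gru.weight_ih_l0.copy_(cell.weight_ih)
    gru.weight_hh_l0.copy_(cell.weight_hh)
    gru.bias_ih_l0.copy_(cell.bias_ih)
    gru.bias_hh_l0.copy_(cell.bias_hh)

xs = torch.randn(seq_len, batch, input_size)  # [seq_len, batch, input_size]
h = torch.zeros(batch, hidden_size)           # h_0

outputs = []
with torch.no_grad():
    for t in range(seq_len):
        h = cell(xs[t], h)                    # h(t) = GRUCell(x(t), h(t-1))
        outputs.append(h)
    outputs = torch.stack(outputs)            # [seq_len, batch, hidden_size]
    gru_out, gru_h_n = gru(xs, torch.zeros(1, batch, hidden_size))

print(outputs.shape)                                # torch.Size([4, 2, 10])
print(torch.allclose(outputs, gru_out, atol=1e-5))  # True
```

The loop computes the same recurrence as nn.GRU. The step-by-step form is mainly useful when something must happen between steps, for example feeding a prediction back in as the next input.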