小白撸代码
import torch.nn as nn
import torch
from torch.autograd import Variable
rnn = nn.GRUCell(10,20)#规定输入维度10 隐藏维度为20
#包含输入特征的Tensor
#6大行矩阵,没大行为3行10列 10是GRUCell的10
input = Variable(torch.randn(6,3,10))
#保存着batch中每个元素的初始化隐状态的Tensor
#隐藏层为3行20列;3要和input中的3保持高度一致20为GRUCell的20
#(说一下,之前的LSTM、LSTMCell、GRU中的隐藏层
# 和细胞层的行都要高度一致)
hx = Variable(torch.randn(3,20))
# print(hx,"hx1###########")
#输出
output = []
#将输出格式保持为(6,3,20)20为要和GRUCell中20要一致
for i in range(6):#6 不是一定为6 可以是其他值
hx = rnn(input[i] , hx)#保存着RNN下一个时刻的隐状态。
output.append(hx)
# print(hx,"hx2###############################")
# print(output)