import torch
# Sanity check: reproduce torch.nn.RNN's forward pass by hand.
#
# For every layer l and time step t, torch.nn.RNN computes
#     h_t = act(x_t @ W_ih_l.T + b_ih_l + h_{t-1} @ W_hh_l.T + b_hh_l)
# where act is tanh (the default) or relu. Layer 0 reads the input; each later
# layer reads the previous layer's outputs. Parameters are registered per layer
# in the order weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}.


def manual_rnn_forward(rnn, x, h0=None):
    """Recompute the forward pass of a ``torch.nn.RNN`` step by step.

    Args:
        rnn: A unidirectional ``torch.nn.RNN``. Inter-layer dropout is not
            reproduced, so the module must be in eval mode or have
            ``dropout == 0`` when it has more than one layer.
        x: Input laid out exactly as ``rnn`` expects it: (seq, batch, input),
            (batch, seq, input) when ``rnn.batch_first``, or (seq, input) for
            unbatched input.
        h0: Initial hidden state, (num_layers, batch, hidden), or
            (num_layers, hidden) for unbatched input. Zeros when ``None``.

    Returns:
        ``(output, h_n)`` with the same shapes that ``rnn(x, h0)`` returns:
        ``output`` holds the last layer's hidden state at every time step, and
        ``h_n`` holds every layer's hidden state at the final time step.

    Raises:
        ValueError: if ``rnn`` is bidirectional or would apply dropout.
    """
    if rnn.bidirectional:
        raise ValueError("manual_rnn_forward only supports unidirectional RNNs")
    if rnn.training and rnn.dropout > 0 and rnn.num_layers > 1:
        raise ValueError(
            "inter-layer dropout is random and cannot be reproduced; call rnn.eval()"
        )
    act = torch.tanh if rnn.nonlinearity == "tanh" else torch.relu

    # Work time-major internally so that iterating over dim 0 walks time steps.
    # batch_first is ignored by torch for unbatched (2-D) input.
    batch_first = rnn.batch_first and x.dim() == 3
    layer_input = x.transpose(0, 1) if batch_first else x
    if h0 is None:
        h0 = layer_input.new_zeros(
            rnn.num_layers, *layer_input.shape[1:-1], rnn.hidden_size
        )

    final_states = []
    for layer in range(rnn.num_layers):
        w_ih = getattr(rnn, f"weight_ih_l{layer}")
        w_hh = getattr(rnn, f"weight_hh_l{layer}")
        b_ih = getattr(rnn, f"bias_ih_l{layer}") if rnn.bias else 0.0
        b_hh = getattr(rnn, f"bias_hh_l{layer}") if rnn.bias else 0.0
        h_t = h0[layer]  # each layer starts from its own slice of h0
        steps = []
        for x_t in layer_input:
            h_t = act((x_t @ w_ih.t() + b_ih) + (h_t @ w_hh.t() + b_hh))
            steps.append(h_t)
        layer_input = torch.stack(steps)  # this layer's outputs feed the next layer
        final_states.append(h_t)

    output = layer_input.transpose(0, 1) if batch_first else layer_input
    return output, torch.stack(final_states)


x = torch.ones(1, 1, 1)    # input: (seq_len=1, batch=1, input_size=1), batch_first=False
b0 = torch.zeros(2, 1, 1)  # initial hidden state h0 (num_layers=2, batch=1, hidden=1), not a bias
rnn = torch.nn.RNN(1, 1, num_layers=2, batch_first=False)
wb4 = list(rnn.parameters())
o, h = rnn(x, b0)  # o: last layer's output per step; h: final hidden state of each layer
print(o)
print(h)
print('++++++')
# parameters() yields weight_ih, weight_hh, bias_ih, bias_hh for layer 0,
# then the same four for layer 1, so wb4[0:4] describe layer 0 only.
print('wb4', wb4)
print('++++++')
whx = wb4[0]  # weight_ih_l0
whh = wb4[1]  # weight_hh_l0
bw = wb4[2]   # bias_ih_l0
bh = wb4[3]   # bias_hh_l0
rst1 = x @ whx.t()
print(rst1)
rst2 = bw
print(rst2)
rst3 = b0[0] @ whh.t()  # layer 0 uses only its own slice of h0, not all of b0
print(rst3)
rst4 = bh
print(rst4)
print('+++++++')
# This is layer 0's hidden state, i.e. h[0]. The output o comes from layer 1,
# which feeds this value through its own (layer-1) weights.
h_l0 = torch.tanh(rst1 + rst2 + rst3 + rst4)
print(h_l0)
print('layer 0 matches h[0]:', torch.allclose(h_l0[0], h[0], atol=1e-6))
manual_o, manual_h = manual_rnn_forward(rnn, x, b0)
print(
    'all layers match:',
    torch.allclose(manual_o, o, atol=1e-6),
    torch.allclose(manual_h, h, atol=1e-6),
)
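# pytorch下rnn的一些思考 (Some thoughts on RNNs in PyTorch)
# 最新推荐文章于 2024-04-07 09:12:00 发布 (scraped page banner: "latest recommended article published 2024-04-07 09:12:00")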