class Net(torch.nn.Module):
def __init__(self):
super(Net,self).__init__()
self.bilstm = torch.nn.LSTM(4,10,batch_first=False,bidirectional=True)
def forward(self,input):
rnn_out,_ = self.bilstm(input) #输出形状 [length:3,batch:2,hidden:20]
encoding = torch.cat([rnn_out[0], rnn_out[-1]], dim=1) ## 选择双向hidden的最后一个向量,并把他们拼接成一个向量
#cat 拼接向量 dim=0是按第一个维度拼接,dim=1是按第二个维度拼接,dim=2 以此类推
return encoding
net = Net()
input = torch.rand(6,2,4)
input
out = net(input)
out
输出:
tensor([[[0.8778, 0.0818, 0.6875, 0.8859],
[0.5734, 0.2775, 0.3590, 0.2174]],
[[0.9982, 0.5425, 0.3168, 0.1968],
[0.3750, 0.4696, 0.5498, 0.0981]],
[[0.1177, 0.0353, 0.0145, 0.5931],
[0.9969, 0.8779, 0.0992, 0.0971]],
[[0.2751, 0.5499, 0.0971, 0.6711],
[0.1580, 0.5890, 0.0929, 0.2064]],
[[0.5682, 0.0360, 0.0856, 0.8951],
[0.1791, 0.3881, 0.4138, 0.0593]],
[[0.8850, 0.2704, 0.5389, 0.3144],
[0.8112, 0.8874, 0.0174, 0.1409]]])
tensor([[-0.0343, 0.1127, 0.0279, 0.0927, 0.0789, 0.0817, -0.0883, 0.1394,
-0.0650, 0.0198, -0.2320, -0.1571, -0.2744, 0.0421, -0.0620, -0.1478,
-0.0336, -0.1328, 0.1629, -0.4289, -0.1029, 0.1805, 0.0289, 0.1386,
0.1701, 0.1697, -0.1425, 0.1867, -0.0695, -0.0141, -0.1513, -0.0761,
-0.1537, 0.0194, -0.0216, -0.0107, 0.0298, -0.1285, 0.0615, -0.1595],
[-0.0676, 0.0644, -0.0016, 0.0671, 0.0838, 0.0723, -0.0314, 0.0668,
-0.0258, -0.0100, -0.2025, -0.1248, -0.2316, -0.0192, -0.1111, -0.0949,
0.0649, -0.1628, 0.0912, -0.3886, -0.1119, 0.0934, 0.0270, 0.0647,
0.1424, 0.1345, -0.0460, 0.1188, -0.0534, -0.0672, -0.1266, -0.0497,
-0.1637, -0.0041, -0.0519, 0.0089, 0.0486, -0.1394, -0.0221, -0.1710]],
grad_fn=<CatBackward>)