LSTM输入参数有input_size, hidden_size, num_layers, bidrectional
input_size为输入序列维度的最后一维([8,56,768]),输入的input_size填写768
hidden_size可以理解为输出的序列的最后一维
num_layers表示LSTM堆叠的层数
bidrectional为布尔类型,True时表示使用双向的LSTM,False表示为单向的LSTM
例如:我想 [8,56,768] 经过LSTM后维度不变,则需要设置:
a=torch.randn(8,56,768)
lstm=torch.nn.LSTM(768,384,10,bidirectional=True, batch_first=True)
out,(h,c)=lstm(a)
print("out:",out.size())
此处为什么设置hidden_size等于384,因为使用的是双向的LSTM,hidden_size输出时会翻倍
如果是单向的LSTM则设置hidden_size为768(谨记录个人使用过程)