torch.nn.LSTM()函数维度详解

参考:

https://pytorch.org/docs/stable/nn.html#lstm

https://blog.csdn.net/yangyang_yangqi/article/details/84585998

https://zhuanlan.zhihu.com/p/39191116

https://blog.csdn.net/m0_37586991/article/details/88561746

简单应用

直接看官网的文档太崩溃了,先从使用说起,上代码

import torch
import torch.nn as nn

lstm = nn.LSTM(10, 20, 2)  # 实例化,         括号里的参数(input_size,hidden_size,num_layers)

x = torch.randn(5, 3, 10)  # 准备输入张量,    括号里的参数(seq_len,batch,input_size)
h0 = torch.randn(2, 3, 20)  # hidden初始状态,括号里的参数(num_layers*num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)  # cell初始状态,  括号里的参数(num_layers * num_directions, batch, hidden_size)

output, (hn, cn) = lstm(x, (h0, c0))

print(output.size())
print(hn.size())
print(cn.size())

输出

torch.Size([5, 3, 20])
torch.Size([2, 3, 20])
torch.Size([2, 3, 20])

代码中写出了lstm三个输入的形状,这里再明确一下,先不讨论batch_first与否

默认情况batch_first=False

输入数据格式:
input(seq_len, batch, input_size)
h0(num_layers * num_directions, batch, hidden_size)
c0(num_layers * num_directions, batch, hidden_size)

输出数据格式:
output(seq_len, batch, hidden_size * num_directions)
hn(num_layers * num_directions, batch, hidden_size)
cn(num_layers * num_directions, batch, hidden_size)
 

用颜色表示更清晰

举例分析

再着重讲一下输入的数据x

x = torch.randn(5, 3, 10)

可以理解为:1个batch中有3个句子,每个句子5个单词,每个单词用10维的向量表示;而句子的长度是不一样的,所以seq_len可长可短,这也是LSTM可以解决长短序列的特殊之处。只有seq_len这一参数是可变的。

笔者主要是用来复现一个单目视觉里程计,对应来看:每帧图像对应着上述的单词,每个样本中有多少帧图像seq_len就是多少,图像feature map拉直后就是embedding_size,也就是input_size,最后的需要输出的相机Pose,先不考虑后续的全连接层,希望每帧图像输出的6维的张量,因此hidden_size设置为6(单向lstm)

再用参考链接里的举个例子:
对句子进行LSTM操作

假设有100个句子(sequence),每个句子里有7个词,batch_size=64,embedding_size=300

此时,各个参数为:
input_size=embedding_size=300
batch=batch_size=64
seq_len=7

另外设置hidden_size=100, num_layers=1

batch_first的情况

batch_first: 输入输出的第一维是否为 batch_size,默认值 False。因为 Torch 中,人们习惯使用Torch中带有的dataset,dataloader向神经网络模型连续输入数据,这里面就有一个 batch_size 的参数,表示一次输入多少个数据。 在 LSTM 模型中,输入数据必须是一批数据,为了区分LSTM中的批量数据和dataloader中的批量数据是否相同意义,LSTM 模型就通过这个参数的设定来区分。 如果是相同意义的,就设置为True,如果不同意义的,设置为False。 torch.LSTM 中 batch_size 维度默认是放在第二维度,故此参数设置可以将 batch_size 放在第一维度。如:input 默认是(4,1,5),中间的 1 是 batch_size,指定batch_first=True后就是(1,4,5)。所以,如果你的输入数据是二维数据的话,就应该将 batch_first 设置为True;

结构推导

未完

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值