torch.nn.LSTM()

链接1:
单个RNN单元可以使用torch.nn.RNNCell(), LSTMCell(), GRUCell() ——使用for循环来处理循环神经网络(时间维度)


可以直接调用torch.nn.RNN(), LSTM(), GRU()

rnn = nn.LSTM(10, 20, 2)  
# 初始化LSTM, 输入x的特征数=10, 输出隐藏状态的特征数=20, LSTM的层数=2

input = torch.randn(5, 3, 10)  
# 字段长度/时间步总数=5, batch_size=3, 输入x的特征数=10

h0 = torch.randn(2, 3, 20)  
# 初始隐藏状态特征: 层数*方向数(双向LSTM时=2)=2, batch_size=3, 输出隐藏状态的特征数=20. 若没定义默认为0

c0 = torch.randn(2, 3, 20)  
# 初始细胞状态特征: 同上

output, (hn, cn) = rnn(input, (h0, c0))  
# 返回的output是LSTM最后一层(这里是第二层)所有的字/时间步 的输出特征, output的shape=(5, 3, 20); 
hn和cn是最后一个字/时间步 的隐藏状态特征和细胞状态特征,shape和h0,c0一样.

链接2:
模型的参数input_size与模型的输入input是不同的。

参数:

input_size——输入数据的特征维数,通常就是embedding_dim(词向量的维度)
hidden_size——LSTM中隐层的维度
num_layers——循环神经网络的层数

bias——用不用偏置,default=True;

False,the layer does not use bias weights b_ih and b_hh.

batch_first——这个要注意,通常我们输入的数据shape=(batch_size,seq_length,embedding_dim),而batch_first默认是False,所以我们的输入数据最好送进LSTM之前将batch_size与seq_length这两个维度调换

dropout——默认是0,代表不用dropout

If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0

bidirectional默认是false,代表不用双向LSTM

输入

输入数据包括input,(h_0,c_0):
input就是shape==(seq_length,batch_size,input_size)的张量

h_0的shape==(num_layers * num_directions, batch, hidden_size)的张量,它包含了在当前这个batch_size中每个句子的初始隐藏状态,num_layers就是LSTM的层数,如果bidirectional=True,num_directions=2,否则就是1,表示只有一个方向,

c_0和h_0的形状相同,它包含的是在当前这个batch_size中的每个句子的初始细胞状态。h_0,c_0如果不提供,那么默认是0

输出

输出数据包括output,(h_n,c_n):
output的shape==(seq_length, batch_size,num_directions * hidden_size),
它包含的LSTM的最后一层的输出特征(h_t),t是batch_size中每个句子的长度.
h_n.shape==(num_directions * num_layers, batch, hidden_size)
c_n.shape==h_n.shape
h_n包含的是句子的最后一个单词的隐藏状态,c_n包含的是句子的最后一个单词的细胞状态,所以它们都与句子的长度seq_length无关。
output[-1]与h_n是相等的,因为output[-1]包含的正是batch_size个句子中每一个句子的最后一个单词的隐藏状态,注意LSTM中的隐藏状态其实就是输出,cell state细胞状态才是LSTM中一直隐藏的,记录着信息


补充:b_ih and b_hh看下图,可知,是指4个门的bias

i t i_t it f t f_t ft g t g_t gt o t o_t otare the input, forget, cell, and output gates,

b_ih 是纵向(网络层)
b_hh是横向(时序)
在这里插入图片描述
在这里插入图片描述


实战:训练网络帮我我们标注词性。
目前还没时间和能力去看。。。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值