[PyTorch] rnn,lstm,gru中输入输出维度

[PyTorch] rnn,lstm,gru中输入输出维度

本文中的RNN泛指LSTM,GRU等等
CNN中和RNN中batchSize的默认位置是不同的。
CNN中:batchsize的位置是position 0.
RNN中:batchsize的位置是position 1.

在RNN中输入数据格式:
对于最简单的RNN,我们可以使用两种方式来调用,torch.nn.RNNCell(),它只接受序列中的单步输入,必须显式的传入隐藏状态。torch.nn.RNN()可以接受一个序列的输入,默认会传入一个全0的隐藏状态,也可以自己申明隐藏状态传入。
输入大小是三维tensor[seq_len,batch_size,input_dim]
input_dim是输入的维度,比如是128
batch_size是一次往RNN输入句子的数目,比如是5。
seq_len是一个句子的最大长度,比如15
所以千万注意,RNN输入的是序列,一次把批次的所有句子都输入了,得到的ouptut和hidden都是这个批次的所有的输出和隐藏状态,维度也是三维。
**可以理解为现在一共有batch_size个独立的RNN组件,RNN的输入维度是input_dim,总共输入seq_len个时间步,则每个时间步输入到这个整个RNN模块的维度是[batch_size,input_dim]

在这里插入图片描述
问题1:这里out、ht的size是多少呢?
回答:out:6 * 3 * 10, ht: 2 * 3 * 10,out的输出维度[seq_len,batch_size,output_dim],ht的维度[num_layers * num_directions, batch, hidden_size],如果是单向单层的RNN那么一个句子只有一个hidden。
问题2:out[-1]和ht[-1]是否相等?
回答:相等,隐藏单元就是输出的最后一个单元,可以想象,每个的输出其实就是那个时间步的隐藏单元

作者:MapReducer
链接:https://www.jianshu.com/p/b942e65cb0a3
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
作者:MapReducer
链接:https://www.jianshu.com/p/b942e65cb0a3
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值