学习笔记——torch.nn.RNN()

1. 调用方式

官方文档RNN — PyTorch 1.13 documentation

用于实现RNN层,并可通过传入参数实现多层堆叠(深度循环神经网络)、双向传播(双向循环神经网络):

示例(包含输入输出格式):

import torch
import torch.nn as nn
myrnn = nn.RNN(4,3,2,batch_first=True) #input_size,hidden_size,num_layers
print("myrnn:", myrnn)
input = torch.randn(2,3,4) #输入数据集格式(batch_size, sequence_length, input_size(已限定为4))
print("input:", input)
output, h_n = myrnn(input) #output为每个时刻的隐藏状态,格式为(batch_size,sequence_length,hidden_size);h_n为最后时刻的隐藏状态,格式为(num_layers,batch_size,hidden_size)
print("output:", output)
print("h_n:", h_n)

输出:

* hidden_size类似于全连接网络的结点个数

2. 关于batch_first

输入数据集格式(batch_first默认False):

其中,N=batch_size批量大小,L=sequence_length序列长度,H=input_size输入尺寸。

默认顺序为(sequence_length,batch_size,input_size),与通常batch_size在第一维度有所不同,原因参考MultiHeadAttension源码解析——batch_first参数含义_coder1479的博客-CSDN博客读PyTorch源码学习RNN(1) - 知乎可知:

“由于RNN是序列模型,只有 t1 时刻计算完成,才能进入 t2 时刻,而"batch"就体现在每个时刻 ti 的计算过程中,图中 t1 时刻将["A", "D"]作为当前时刻的batch数据,t2 时刻将["B", "E"]作为当前时刻的batch数据,可想而知,"A"与"D"在内存中相邻比"A"与"B"相邻更合理,这样取数据时才更高效。” 

实际使用中可将batch_first设置为True,按照(batch_size,sequence_length,input_size)顺序传入参数,此时函数会自动将其转换成默认的顺序(sequence_length,batch_size,input_size),并且在输出结果的时候,再转换回来。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值