Pytorch nn.RNN()解析

RNN基本结构与nn.RNN()参数介绍可参考:
参数介绍
官方文档
以下代码对 nn.RNN() 的简单应用进行了注解介绍

import torch
import torch.nn as nn
import torch.functional as F


# 单层RNN,输入x特征为10,输出特征为20, 两层堆叠
rnn = nn.RNN(10, 20, 2)
# 随机构建输入,这里假设每句话只有10个单词,共3句话(即一个batch 3句话),每个单词被embedding为10维向量
inputs = torch.rand(10, 3, 10)
# 随机构建h0, 因为只有2层单向,所以参数为2
h_0 = torch.zeros(2, 3, 20)
# 输出结果,其中output为每一个输入(每一个单词)对应的最终输出值;h_0为最后一个单词在各层的输出值(此处有两层)
output, h_n = rnn(inputs, h_0)


# torch.Tensor
print(type(output))
# (10, 3, 20)
print(output.shape)
print(output)
print("=========================")
# torch.Tensor
print(type(h_n))
# (2,3,20)
print(h_n.shape)
print(h_n)
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值