pytorch学习笔记(7):RNN和LSTM实现分类和回归

本文介绍了如何在PyTorch中使用RNN和LSTM实现分类和回归任务。通过RNN建立的回归器用于预测sin曲线,而LSTM则在MNIST数据集上实现了一个分类器。文章详细解释了RNN和LSTM的原理、参数设置以及在PyTorch中的实现步骤。
摘要由CSDN通过智能技术生成

参考文档:https://mp.weixin.qq.com/s/0DArJ4L9jXTQr0dWT-350Q

在第三篇文章中,我们介绍了 pytorch 中的一些常见网络层。但是这些网络层都是在 CNN 中比较常见的一些层,关于深度学习,我们肯定最了解的两个知识点就是 CNN 和 RNN。那么如何实现一个 RNN 呢?这篇文章我们用 RNN 实现一个分类器和一个回归器。

本文需要你最好对 RNN 相关的知识有一个初步的认识,然后我会尽可能的让你明白在 pytorch 中是如何去实现这一点的。

1、pytorch提供了哪些RNN?

如果我们对 RNN 有所了解,就会知道 RNN 有很多变种,就像 CNN 也有很多变种一样。只要带了 recurrent 的功能,就都属于 RNN 的范畴。那么在 pytorch 中提供了哪些 RNN 呢?
在这里插入图片描述
上面这幅图是 pytorch 源码中的结构,可以看到除了一个 RNNBase() 类,下面还有 RNN,LSTM,GRU 分别继承了 RNNBase() 类,实现了三个 RNN 子类。这三个也就是 pytorch 提供的 RNN 类型。

今天的文章,我们分别通过源码中的 doc 和一些介绍来了解 RNN 和 LSTM,然后分别用它们实现一个回归器和一个分类。关于 GRU 的部分,就留给大家自己去展开啦。

2、RNN,以及实现一个回归器

这一部分,我们先从 RNN 开始进行介绍,分别简单介绍一下 RNN 的原理,在 pytorch 中使用它的一些参数要求,最后是一个回归器,用 sin 曲线作为输入,cos 曲线作为 label,判断函数的拟合能力。

2.1、简单介绍RNN

首先看一下 RNN 的内容,常见的介绍 RNN 的文章中都会有这样一幅图:
在这里插入图片描述
X 0 , X 1 X_0,X_1 X0X1 等等分别是输入序列的一个维度上的数据, X 0 X_0 X0 首先传进去,生成 h 0 h_0 h0 作为第一个隐状态。然后 h 0 h_0 h0 X 1 X_1 X1 一起作为下一个时间序列上的输入,它们的输出再和 X 2 X_2 X2 作为下下个时间序列的输入,以此类推。

具体的细节我们就不展开讲了,默认大家对理论层面已经有了了解。在 pytorch 的源码 doc 中也给出了,对下面的公式进行计算: h t = t a n h ( W i h x t + b i h + W h h h ( t − 1 ) + b h h ) h_t=tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{(t-1)}+b_{hh}) ht=tanh(Wihxt+bih+Whhh(t1)+bhh)这个式子也是对于 RNN 的常见描述。 W i h W_{ih} Wih 表示对输入数据进行处理的权重,而 W h h W_{hh} Whh 则表示对上一个时间序列的隐状态进行处理的权重。 b i h b_{ih}

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值