基础概念 _ 循环神经网络RNN

本文深入介绍了循环神经网络(RNN)的诞生原因、基本架构和与前馈神经网络的区别。RNN因其记忆功能在处理时序数据如视频、语音时表现出优势,但并行计算能力受限。通过PyTorch展示了RNN模型的实现,以生成滞后正弦函数为例,详细解释了训练过程和模型效果。
摘要由CSDN通过智能技术生成

1 前言

不少同学让我出一起RNN的教学文章,学长为了满足各位同学的要求,写下此文

1.1 RNN诞生的原因

在普通的前馈神经网络(如多层感知机MLP,卷积神经网络CNN)中,每次的输入都是独立的,即网络的输出依赖且仅依赖于当前输入,与过去一段时间内网络的输出无关。但是在现实生活中,许多系统的输出不仅依赖于当前输入,还与过去一段时间内系统的输出有关,即需要网络保留一定的记忆功能,这就给前馈神经网络提出了巨大的挑战。除此之外,前馈神经网络难以处理时序数据,比如视频、语音等,因为时序数据的序列长度一般是不固定的,而前馈神经网络要求输入、输出的维度都是固定的,不能任意改变。出于这两方面的需求,循环神经网络RNN应运而生。

1.2 简介

循环神经网络(Recurrent Neural Network,RNN)是一类具有短期记忆能力的神经网络。在循环神经网络中,神经元既可以如同前馈神经网络中神经元那般从其他神经元那里接受信息,也可以接收自身以前的信息。且和前馈神经网络相比,循环神经网络更加符合生物神经网络的结构。

1.3 与前馈神经网络的差异

就功能层面和学习性质而言,循环神经网络也有异于前馈神经网络。前馈神经网络多用于回归(Regression)和分类(Classification),属于监督学习(Supervised learning);而循环神经网络多用于回归(Regression)和生成(Generation),属于无监督学习(Unsupervised learning)。

2 网络架构

在这里插入图片描述
上图所示为RNN的一个神经元模型。给定一个序列长度为T的输入序列

在这里插入图片描述

在t时刻将该序列中的第t个元素xt送进网络,网络会结合本次输入xt及上一次的 “ 记忆 ” ht-1,通过一个非线性激活函数f()(该函数通常为tanh或relu),产生一个中间值ht(学名为隐状态,Hidden States),该值即为本次要保留的记忆。本次网络的输出为ht的线形变换,即g(ht),其中g()为简单的线性函数。

由于循环神经网络具有时序性,因此其网络架构可以从空间和时间两方面进行了解。

为方便理解,特以此情景举例:假设目前正在训练一个RNN,用于生成视频。训练用的train_data为1000段等长度的视频,划分的batch_size为10,即每次送给网络10段视频。假设每段视频都有50帧,每帧的分辨率为2X2。在此训练过程中,不需要外界给label,将某帧图片input进RNN时,下一帧图片即为label。

2.1 空间角度

在本例中,一个序列长度为T的输入序列即一段50帧视频,其中的x1,x2,x3…x50分别对应着第一帧图片、第二帧图片、第三帧图片…到第五十帧图片。随着时间的推移,依次将每帧图片送进网络,在每个时刻,只送进一帧图片,同时生成一帧图片。该网络架构如下图所示。就如普通的前馈神经网络一般,除了输入输出层的维度固定外,隐藏层的维度及层数都是可以自主设计的。这里假设只有一个隐藏层,该层有5个神经元。需要注意的是,每个具有记忆功能的神经元中所存储的隐状态h,不仅参与本神经元的下次输入,还同时参与本层其他具有记忆功能神经元的下次输入。

在这里插入图片描述

2.2 时间角度

从空间角度观察整个网络,可将网络视为1个4X5X4的循环神经网络RNN,其中隐藏层的5个神经元是具有记忆功能的;从时间角度展开整个网络,可将网络视为50个4X5X4的多层感知机MLP,每个MLP隐层神经元不仅向该MLP的下一层输出,同时还向下一个MLP中与之层数对应的所有神经元输出(下图为求清晰表示,将其化简为一对一,但实质上是一对多),该输出即为隐状态h。由于隐状态需要在MLP之间从前向后传递,因此这50个MLP只能依次运算,不能并行运算。

在这里插入图片描述

循环神经网络独特的神经元结构赋予其记忆功能,但有得有失,该结构也造成其在处理时序数据时不能进行并行运算。目前推动算法进步的一大助力就是算力,而RNN却无法充分利用算力,这也是一部分人不太看好其前景的原因。

3 pytorch 实现RNN

对一段正弦函数进行离散采样作为输入,利用RNN生成滞后的正弦函数采样值。实现环境:Colab。

引入必要头文件。

1 import torch.nn as nn
2 import torch
3 import numpy as np
4 import matplotlib.pyplot as plt
5 %matplotlib inline

采样获取input及label,并可视化。

# 制定画布大小
plt.figure(figsize=(8, 5))

# 每个batch中数据个数
num = 20

# 生成数据
time_steps = np.linspace(
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值