点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
来自 | 知乎
作者 | Lucas
地址 | https://zhuanlan.zhihu.com/p/85995376
RNN 扫盲:循环神经网络解读及其 PyTorch 应用实现
循环神经网络(Recurrent Neural Network,RNN)是一类具有短期记忆能力的神经网络。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,也就是说隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。简单来说,设计 RNN 就是为了处理序列数据。如果说 CNN 是对人类视觉的仿真,那 RNN 不妨先看作是对人类记忆能力的模拟。
为什么需要 RNN? 和 CNN 的主要区别?
CNN 相当于人类视觉,是没有记忆能力的,没有办法根据以前的记忆来处理新任务。而 RNN 是基于人的记忆的想法,期望网络能够机主前面出现的特征,根据特征完成下游任务。
CNN 需要固定长度的输入、输出,RNN 的输入和输出可以是不定长且不等长的
CNN 只有 one-to-one 一种结构,而 RNN 有多种结构。
结构组成
一个简单 RNN 由三个部分组成,输入层、隐藏层,输出层(废话)如果我们把上面的图展开,循环神经网络也可以画成下面这个样子:
![5f6fa9505953968188d06f036061a58f.png](https://i-blog.csdnimg.cn/blog_migrate/970e740f1780bc2a8968cba8b4e43c1d.png)
为什么循环神经网络可以往前看任意多个输入值呢?
看啊这个是输出层 o和隐藏层 s 的计算公式
如果把公式 2 一直往公式 1 里带,则有:
记忆能力
该模型具有一定的记忆能力,能够按时序依次处理任意长度的信息。前面的输入对未来产生影响。什么意思呢下图所示。当我们将“ What time is it ?" 每个词进入神经网络后都会对下一个词产生影响,
![ef8639eb07c30c8fb35b58a7192e6787.png](https://i-blog.csdnimg.cn/blog_migrate/8e3ee9c15ce6927083a78dd5cc907e82.png)
缺点:梯度消失和梯度爆炸
![381c62820647521a6a27d9b7687243dc.png](https://i-blog.csdnimg.cn/blog_migrate/7e53eaed3674ee3ade1108ad8c811f93.png)
通过上面的例子,我们已经发现,短期的记忆影响较大(如橙色区域),但是长期的记忆影响就很小(如黑色和绿色区域),这就是 RNN 存在的短期记忆问题。
莫烦 Python 这里讲解的非常生动形象:
‘我今天要做红烧排骨, 首先要准备排骨, 然后…., 最后美味的一道菜就出锅了’。现在请 RNN 来分析, 我今天做的到底是什么菜呢. RNN可能会给出“辣子鸡”这个答案. 由于判断失误, RNN就要开始学习 这个长序列 X 和 ‘红烧排骨’ 的关系 , 而RNN需要的关键信息 ”红烧排骨”却出现在句子开头。
![51cdddca62480cd7bd89e0a614d3864a.png](https://i-blog.csdnimg.cn/blog_migrate/6d8c1bb0d0df044598b62493f98363a7.png)
![722879c3318cbd5ca930d25e0a2a45f1.png](https://i-blog.csdnimg.cn/blog_migrate/cd7a8fe33b76dfd1cbccafc8dd4c3eea.png)
红烧排骨这个信息原的记忆要进过长途跋涉才能抵达最后一个时间点. 然后我们得到误差, 而且在 反向传递 得到的误差的时候, 他在每一步都会 乘以一个自己的参数 W. 如果这个 W 是一个小于1 的数, 比如0.9. 这个0.9 不断乘以误差, 误差传到初始时间点也会是一个接近于零的数, 所以对于初始时刻, 误差相当于就消失了. 我们把这个问题叫做梯度消失或者梯度弥散 Gradient vanishing. 反之如果 W 是一个大于1 的数, 比如1.1 不断累乘, 则到最后变成了无穷大的数, RNN被这无穷大的数撑死了, 这种情况我们叫做梯度爆炸, 这就是普通 RNN 没有办法回忆起久远记忆的原因.
在此感谢
@莫烦
前辈关于机器学习机器学习基础知识的讲解,让晚辈能够迅速成长,知识分享是人的一种高贵品质。我由衷感谢您!
基础模型的 PyTorch 实现
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(input_size + hidden_size, output_size)
def forward(self, input, hidden):
# 将input和之前的网络中的隐藏层参数合并。
combined = torch.cat((input, hidden), 1)
hidden = self.i2h(combined) # 计算隐藏层参数
output = self.i2o(combined) # 计算网络输出的结果
return output, hidden
def init_hidden(self):
# 初始化隐藏层参数hidden
return torch.zeros(1, self.hidden_size)
参考资料
morvanzhou.github.io/tu
easyai.tech/ai-definiti
代码地址:
github.com/zy1996code/n
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~