戳simple-recurrent-network.lua下载源码。
Torch的重要作者之一,Nicolas Leonard对Torch中的nn库进行扩展,发布了rnn库。这个库包含RNN, LSTM, GRU, BRNN, BLSTM等能够处理时序和记忆的网络。
本文以其中Simple Recurrent Network源码为例,讲解Torch中使用RNN1的基本方式。x是自然数序列。
概述
这篇代码实现的功能是:输入一个序列x,能够输出一个序列y。
在真实应用中,输入输出序列可能是不同语种的一句话。这里做了极大简化:y=x+1 mod 10。
网络结构
为了使用RNN,需要包含rnn库:
require 'rnn'
首先设定关键参数。RNN的核心是隐变量 h h h,记录了系统当前时刻的状态。
batchSize = 8
rho = 5 -- sequence length
hiddenSize = 7 -- 隐变量维度
nIndex = 10 -- 输出分类数量
lr = 0.1 --学习率
Recurrent类型
下面可以建立rnn库中的Recurrent类型(参看Recurrent.lua)建立核心模块r
:
local r = nn.Recurrent(
hiddenSize, -- start
nn.LookupTable(nIndex, hiddenSize), -- input
nn.Linear(hiddenSize, hiddenSize), -- feedback
nn.Sigmoid(), -- transfer
rho -- rho
)
创建Recurrent的四个参数意义如下:
start
- 指明隐变量维度。或者,指明从input
到transfer
模块之间的操作。
input
- 指明从输入到隐变量的操作。此处是个nn库中的查找表2。
feedback
- 从前一时刻transfer
之前到当前transfer
函数之前的操作,和input
结果逐元素相加。此处是个全连接层。
transfer
- 非线性函数,可以取ReLU, Sigmoid等。
rho
- 反向传播的步数。最多只向前考虑这么多步骤。
参考源码中的updateOutput
函数,给出了输入和输出的关系:
o u t p u t t = t r a n s f e r ( f e e d b