neuraltalk2-代码解析-(2)

昨天发的博客居然有人进来浏览,心里还是很开心的,我也仔细看了解析(1),发现里面有很多问题,昨天博客用了7,8个小时完成,比较匆忙,由于我这博客主要是为了和初学torch平台,image caption的新手分享,所以我会在标签上加关于新手的标签,我也会在解析代码过程简略分析torch中的各各方法,函数,努力争取让没学过torch,lua的人也能读懂,加油!

这篇博客来解析neuraltalk2中,misc文件夹下LSTM.lua这个部分,这个部分较为简单,对于我这种新手来说也颇为重要,所以今天我也不会吝惜笔墨,和小白们一起闯一闯(其实我已经看过一些其他LSTM代码了^_^),我先贴代码。

如果大家读这篇博客,想必都对LSTM结构有清晰的了解。一切关于Torch方法的解释我会在代码块后面给出
require 'nn'
require 'nngraph'

local LSTM = {}
--这个LSTM层的构建由4个参数决定,分别为input_size,output_size,rnn_size,n,dropout
--input_size为LSTM输入向量的大小
--output_size为输出向量的大小
--rnn_size为中间层的大小,也可以理解为隐藏层的大小,或每层输出大小
--n为LSTM的层数,可以说是从输入一个词向量,到输出一个词向量预测,需要进过几个循环
--dropout为处理过拟合的一种方式,它控制着有多少个神经元不更新,但其不停止工作,是参与运算的,dropout的值取值范围是[0,1),若dropout=0.5,表示有50%的权值在反向传播时是不会更新的
function LSTM.lstm(input_size, output_size, rnn_size, n, dropout)
  dropout = dropout or 0 

  -- there will be 2*n+1 inputs
  local inputs = {}
  --这里是插入输入向量,这个向量可以理解为词向量,inputs为输入索引。
  table.insert(inputs, nn.Identity()()) -- indices giving the sequence of symbols
  --这里是插入prev_c与prev_h,因为LSTM的输入是由输入词向量x,前一层的hide单元,与前一层的cell单元组成,这里为每一LSTM层都创建了输入的接口
  for L = 1,n do
    table.insert(inputs, nn.Identity()()) -- prev_c[L]
    table.insert(inputs, nn.Identity()()) -- prev_h[L]
  end
  local x, input_size_L
  local outputs = {}
  for L = 1,n do
    -- c,h from previos timesteps
    --从输入索引中取出prev_h与prev_c,在每一层储存时是cell在前,hide在后
    local prev_h = inputs[L*2+1]
    local prev_c = inputs[L*2]
    -- the input to this layer
    if L == 1 then 
      --如果是初始输入,公式中输入向量x因为这模块接口的输入词向量,输入词向量的接口为inputs[1]
      x = inputs[1]
      input_size_L = input_size
    else
      --若不是第一次输入,则输入接口为上一层的输出 
      x = outputs[(L-1)*2] 
      --如果dropout大于0,对输入x接口进行dropout处理
      if dropout > 0 then x = nn.Dropout(dropout)(x):annotate{name='drop_' .. L} end -- apply dropout, if any
      input_size_L = rnn_size
    end
    -- evaluate the input sums at once for efficiency
    --这行代码是创建一个从input_size_l,到4*rnn_size维度的线性变换,这里为什么是4*rnn_size,因为在LSTM中output,input,forget,cell gates公式中都运用到了w*i这个线性变换,写在一起,为了高效
    local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L}
    --同上行代码,pre_h*w在四个部分同样用到
    local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L}
    --创建一个module,将对i的线性变换与对h的线性变换相加等同于prev_h*w+i*w
    --在torch代码编写是无需考虑bias,为什么呢,我暂时还没深究
    local all_input_sums = nn.CAddTable()({i2h, h2h})
    --创建一个reshape模块,将all_input_sums按照rnn_size分解为4个,分别为input,output,forget,cell的输入的一部分
    local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
    --这是将reshaped彻底分解出来
    local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
    -- decode the gates
    -- 对input,output,forget gates作非线性处理,这些变量的父类都是nn.module,以下代码与LSTM中公式的对应,相信大家都能找到对应的公式
    local in_gate = nn.Sigmoid()(n1)
    local forget_gate = nn.Sigmoid()(n2)
    local out_gate = nn.Sigmoid()(n3)
    -- decode the write inputs
    local in_transform = nn.Tanh()(n4)
    -- perform the LSTM update
    local next_c           = nn.CAddTable()({
        nn.CMulTable()({forget_gate, prev_c}),
        nn.CMulTable()({in_gate,     in_transform})
      })
    -- gated cells form the output
    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

    table.insert(outputs, next_c)
    table.insert(outputs, next_h)
  end

  -- set up the decoder
  local top_h = outputs[#outputs]
  if dropout > 0 then top_h = nn.Dropout(dropout)(top_h):annotate{name='drop_final'} end
  local proj = nn.Linear(rnn_size, output_size)(top_h):annotate{name='decoder'}
  --将最后的输出做logsoftmax处理,分类
  local logsoft = nn.LogSoftMax()(proj)
  table.insert(outputs, logsoft)
  --最后返回的是一个gModule类型的容器,输入接口为inputs,outputs
  return nn.gModule(inputs, outputs)
end

return LSTM
  • nn.dropout()

    • nn.dropout这个类继承于nn.module,例n2=nn.dropout(p)(n1),将输入接口n1,dropout化,如果不懂什么是dropout,参照dropout
  • annotate()

    • 很好理解,就是添加注释,没其他功能。
  • nn.Linear()

    • 这个类继承于nn.module(),nn.Linear(a,b,bias)一般条件下bias可以省略,他创建了一个从a维到b维的线性映射神经网络模块
  • nn.CAddTable()

    • 这个类继承于nn.module(),nn.CAddTable(ip),这个模块将ip(类型为List)中的各各模块做直接的每个元素对应相加
  • nn.CMulTable

    • 这个模块同上个模块,不过是做乘法
  • nn.reshape

    • 这个模块同样继承于nn.module,他的功能是将输入模块重新改变结构,如上文代码中其把原本是4*rnn_size长度,一维的模块,重新reshape变成一个2维,[4,rnn_size]的模块
  • nn.SplitTable

    • 继承于nn.module(n),表示将模块从第n维分割,返回一个存储子模块的table

    大家会发现弄懂,搞清楚nn.module这个类多么重要,如果我觉得我比较全面弄懂了nn.module我会写博客的。

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值