[torch] Saving the initial state of FastLSTM

installpath/torch/rnn/FastLSTM.lua

Before:

-- FastLSTM subclasses nn.LSTM; `parent` is used to chain constructors.
local FastLSTM, parent = torch.class("nn.FastLSTM", "nn.LSTM")

-- set this to true to have it use nngraph instead of nn
-- setting this to true can make your next FastLSTM significantly faster
FastLSTM.usenngraph = false
-- set this to true to batch-normalize the gate activations (eps/momentum/affine
-- in __init); buildModel then falls through to the nngraph variant
FastLSTM.bn = false

function FastLSTM:__init(inputSize, outputSize, rho, eps, momentum, affine)
   -- Batch-norm hyper-parameters: variance floor and momentum both default
   -- to 0.1, the affine (learnable scale/shift) flag defaults to true.
   self.eps = eps or 0.1
   self.momentum = momentum or 0.1 --gamma
   if affine == nil then
      self.affine = true
   else
      self.affine = affine
   end

   -- Delegate the recurrent bookkeeping (sizes, rho) to nn.LSTM.
   parent.__init(self, inputSize, outputSize, rho)
end

function FastLSTM:buildModel()
   -- Builds the single-timestep module of the LSTM as a composition of
   -- plain nn containers (unless usenngraph/bn routes to nngraphModel).
   -- input : {input, prevOutput, prevCell}
   -- output : {output, cell}

   -- Calculate all four gates in one go : input, hidden, forget, output
   self.i2g = nn.Linear(self.inputSize, 4*self.outputSize)
   self.o2g = nn.LinearNoBias(self.outputSize, 4*self.outputSize)

   if self.usenngraph or self.bn then
      require 'nngraph'
      return self:nngraphModel()
   end

   -- gates : {input, prevOutput, prevCell} -> 4 gate activations.
   -- i2g(input) + o2g(prevOutput) produces all 4*outputSize pre-activations
   -- in a single pair of matrix multiplies.
   local para = nn.ParallelTable():add(self.i2g):add(self.o2g)
   local gates = nn.Sequential()
   gates:add(nn.NarrowTable(1,2))
   gates:add(para)
   gates:add(nn.CAddTable())

   -- Reshape to (batch_size, n_gates, hid_size)
   -- Then slice the n_gates dimension, i.e dimension 2
   gates:add(nn.Reshape(4,self.outputSize))
   gates:add(nn.SplitTable(1,2))
   -- Sigmoid for the input/forget/output gates; Tanh for the candidate
   -- hidden state (slice 2).
   local transfer = nn.ParallelTable()
   transfer:add(nn.Sigmoid()):add(nn.Tanh()):add(nn.Sigmoid()):add(nn.Sigmoid())
   gates:add(transfer)

   -- Carry prevCell alongside the gate outputs; after FlattenTable the
   -- running table is {inputGate, hiddenCand, forgetGate, outputGate, prevCell}.
   local concat = nn.ConcatTable()
   concat:add(gates):add(nn.SelectTable(3))
   local seq = nn.Sequential()
   seq:add(concat)
   seq:add(nn.FlattenTable()) -- input, hidden, forget, output, cell

   -- input gate * hidden state
   local hidden = nn.Sequential()
   hidden:add(nn.NarrowTable(1,2))
   hidden:add(nn.CMulTable())

   -- forget gate * cell
   local cell = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(nn.SelectTable(3)):add(nn.SelectTable(5))
   cell:add(concat)
   cell:add(nn.CMulTable())

   -- nextCell = inputGate*hiddenCand + forgetGate*prevCell
   local nextCell = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(hidden):add(cell)
   nextCell:add(concat)
   nextCell:add(nn.CAddTable())

   -- Running table becomes {nextCell, outputGate}.
   local concat = nn.ConcatTable()
   concat:add(nextCell):add(nn.SelectTable(4))
   seq:add(concat)
   seq:add(nn.FlattenTable()) -- nextCell, outputGate

   -- output = outputGate * tanh(nextCell)
   local cellAct = nn.Sequential()
   cellAct:add(nn.SelectTable(1))
   cellAct:add(nn.Tanh())
   local concat = nn.ConcatTable()
   concat:add(cellAct):add(nn.SelectTable(2))
   local output = nn.Sequential()
   output:add(concat)
   output:add(nn.CMulTable())

   -- Final result table : {output, nextCell}.
   local concat = nn.ConcatTable()
   concat:add(output):add(nn.SelectTable(1))
   seq:add(concat)

   return seq
end

After:

-- torch-hdf5 is used below (buildModel) to save/load the freshly
-- initialized i2g/o2g parameters; it is a new dependency of this file.
require 'hdf5'
-- FastLSTM subclasses nn.LSTM; `parent` is used to chain constructors.
local FastLSTM, parent = torch.class("nn.FastLSTM", "nn.LSTM")

-- set this to true to have it use nngraph instead of nn
-- setting this to true can make your next FastLSTM significantly faster
FastLSTM.usenngraph = false
-- set this to true to batch-normalize the gate activations (needs nngraph)
FastLSTM.bn = false

function FastLSTM:__init(inputSize, outputSize, rho, eps, momentum, affine, initialfile, ifLoad)
   -- inputSize/outputSize/rho : forwarded to nn.LSTM's constructor.
   -- eps, momentum, affine    : batch-norm hyper-parameters; eps and
   --                            momentum default to 0.1, affine to true.
   -- initialfile              : optional hdf5 path used by buildModel to
   --                            load or dump the i2g/o2g parameters;
   --                            nil/0 disables the feature.
   -- ifLoad                   : 1/"1"/true means load initialfile into the
   --                            model; anything else means save to it.
   self.eps = eps or 0.1 -- initialize batch norm variance with 0.1
   self.momentum = momentum or 0.1 --gamma
   self.affine = affine == nil and true or affine

   self.initialfile = initialfile or 0
   -- BUG FIX: in Lua the number 0 (and the string "0") is truthy, so a
   -- plain `if self.ifLoad then` downstream would treat a caller's 0
   -- ("save") as "load".  Normalize the flag to a real boolean here.
   self.ifLoad = (ifLoad == 1) or (ifLoad == "1") or (ifLoad == true)

   parent.__init(self, inputSize, outputSize, rho)
end

function FastLSTM:buildModel()
   -- Builds the single-timestep module; additionally, when
   -- self.initialfile is set, either restores the i2g/o2g parameters from
   -- that hdf5 file (ifLoad) or dumps the fresh initialization into it.
   -- input : {input, prevOutput, prevCell}
   -- output : {output, cell}

   -- Calculate all four gates in one go : input, hidden, forget, output
   self.i2g = nn.Linear(self.inputSize, 4*self.outputSize)
   self.o2g = nn.LinearNoBias(self.outputSize, 4*self.outputSize)

   if self.initialfile ~= 0 then
      -- BUG FIX: compare explicitly instead of `if self.ifLoad then` —
      -- in Lua 0 and "0" are truthy, so a caller passing 0 ("save")
      -- would otherwise trigger the load branch.
      if self.ifLoad == 1 or self.ifLoad == "1" or self.ifLoad == true then
         -- Restore previously saved parameters.  Use :copy() so the
         -- values land in the tensors the modules already own; replacing
         -- the tensor objects (as plain assignment does) would silently
         -- detach them from any storage flattened later by
         -- getParameters().
         local myFile = hdf5.open(self.initialfile, 'r')
         self.i2g.weight:copy(myFile:read('i2g_weight'):all())
         self.i2g.bias:copy(myFile:read('i2g_bias'):all())
         self.i2g.gradWeight:copy(myFile:read('i2g_gradWeight'):all())
         self.i2g.gradBias:copy(myFile:read('i2g_gradBias'):all())
         self.o2g.weight:copy(myFile:read('o2g_weight'):all())
         self.o2g.gradWeight:copy(myFile:read('o2g_gradWeight'):all())
         myFile:close()
      else
         -- Dump the freshly initialized parameters so a later run can be
         -- started from exactly the same state.
         local myFile = hdf5.open(self.initialfile, 'w')
         myFile:write('i2g_weight', self.i2g.weight)
         myFile:write('i2g_bias', self.i2g.bias)
         myFile:write('i2g_gradWeight', self.i2g.gradWeight)
         myFile:write('i2g_gradBias', self.i2g.gradBias)
         myFile:write('o2g_weight', self.o2g.weight)
         myFile:write('o2g_gradWeight', self.o2g.gradWeight)
         myFile:close()
      end
   end

   if self.usenngraph or self.bn then
      require 'nngraph'
      return self:nngraphModel()
   end

   --... (rest of the plain-nn construction, unchanged from the original file)
end

After editing, rebuild and reinstall the rnn package so the change takes effect:

cd ~/installpath/torch/rnn
rm -r build/
luarocks make rocks/rnn-scm-1.rockspec
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值