[torch]同时更新多个seqlstm

nn.SeqLSTM 在 backward 的时候需要一些中间参数, 这些中间参数是由 seqLSTM:forward(input) 时生成的; 并且, 每 forward 一次, 这些中间参数就会被重置, 因此 backward 必须与其对应的 forward 配对使用.

maptable

require 'nn'
require 'rnn'
require 'os'

local batch_size  = 5
local feat_dim    = 6
local hidden_size = 4
local seq_len     = 10
local num         = 10
local lr          = 0.01

-------- initialize model (run once to create the snapshot, then keep commented)
--[[
local model = nn.SeqLSTM(feat_dim, hidden_size)
--local model = nn.Sequencer(nn.Linear(feat_dim, hidden_size))
model:clearState()
torch.save("model_init_seqLSTM.t7", model)
os.exit()
--]]

-- Load four copies of the same initial weights so every update strategy
-- below starts from identical parameters.
local model  = torch.load("model_init_seqLSTM.t7")
local model1 = torch.load("model_init_seqLSTM.t7")
local model2 = torch.load("model_init_seqLSTM.t7")
local model3 = torch.load("model_init_seqLSTM.t7")

-------- random inputs and output gradients (one pair per sequence)
local input   = {}
local gradOut = {}
for i = 1, num do
    input[i]   = torch.randn(seq_len, batch_size, feat_dim)
    gradOut[i] = torch.randn(seq_len, batch_size, hidden_size)
end

-------- strategy A: MapTable shares one module across all inputs
local map = nn.MapTable():add(model)
local out = map:forward(input)
map:backward(input, gradOut)
map:updateParameters(lr)
map:zeroGradParameters()

-------- strategy B: single module, forward ONCE then backward for every input
-- NOTE(review): only input[num] is forwarded here, but backward is called for
-- every i. SeqLSTM's backward consumes buffers produced by the matching
-- forward, so for i < num it reuses stale state. This appears to be the
-- deliberate "wrong" baseline whose divergence is printed at the end.
out = model1:forward(input[num])
local gradInputs
for i = num, 1, -1 do
    gradInputs = model1:backward(input[i], gradOut[i])
    model1:updateParameters(lr)
end
model1:forget()
model1:zeroGradParameters()

-------- strategy C (reference): forward/backward/update per input, 1..num
for i = 1, num do
    out = model2:forward(input[i])
    gradInputs = model2:backward(input[i], gradOut[i])
    model2:updateParameters(lr)
    model2:forget()
    model2:zeroGradParameters()
end

-------- strategy D (reference): same as C but in reverse order, num..1
for i = num, 1, -1 do
    out = model3:forward(input[i])
    gradInputs = model3:backward(input[i], gradOut[i])
    model3:updateParameters(lr)
    model3:forget()
    model3:zeroGradParameters()
end

-------- forward again and compare the strategies side by side
out = map:forward(input)
for i, k in pairs(out) do
    local out1 = model1:forward(input[i])
    local out2 = model2:forward(input[i])
    local out3 = model3:forward(input[i])
    print(i)
    --print(out2)   -- reference value (strategy C)
    --print(out3)   -- reference value (strategy D)
    --print(k)      -- MapTable value
    --print(out1)   -- strategy B value
    print(out2 + out3)  -- sum of the two reference outputs
    print(k * 2)        -- MapTable output doubled, for comparison with the sum
    --print(out1 * 2)  -- strategy B doubled (differs noticeably from the above)
end

new way: 为每个输入加载一份独立的模型副本, 在副本之间同步参数

require 'nn'
require 'rnn'
require 'os'

local batch_size  = 5
local feat_dim    = 6
local hidden_size = 4
local seq_len     = 10
local num         = 2
local lr          = 0.01

-------- initialize model (run once to create the snapshot, then keep commented)
--[[
local model = nn.SeqLSTM(feat_dim, hidden_size)
--local model = nn.Sequencer(nn.Linear(feat_dim, hidden_size))
model:clearState()
torch.save("model_init_seqLSTM.t7", model)
os.exit()
--]]

local model1 = torch.load("model_init_seqLSTM.t7")

-------- random inputs and output gradients (one pair per sequence)
local input   = {}
local gradOut = {}
for i = 1, num do
    input[i]   = torch.randn(seq_len, batch_size, feat_dim)
    gradOut[i] = torch.randn(seq_len, batch_size, hidden_size)
end

-------- new way: one freshly-loaded module per input
-- Each clone keeps its own forward buffers, so every backward sees the state
-- produced by its own matching forward. Before each step the clone inherits
-- the parameters of the previously-updated clone, so the weights evolve
-- exactly as they would in a single sequentially-updated module.
local models, out, gradInputs = {}, {}, {}
for i = num, 1, -1 do
    models[i] = torch.load("model_init_seqLSTM.t7")
    if i < num then
        local params_cur, gradParams_cur = models[i]:getParameters()
        local params_updated, gradParams_updated = models[i + 1]:getParameters()
        -- Tensor :copy() replaces the original element-by-element Lua loop.
        params_cur:copy(params_updated)
        gradParams_cur:copy(gradParams_updated)
    end
    out[i] = models[i]:forward(input[i])
    gradInputs[i] = models[i]:backward(input[i], gradOut[i])
    models[i]:updateParameters(lr)  -- was hard-coded 0.01; use the lr constant
    models[i]:forget()
    models[i]:zeroGradParameters()
end

-- Sync every clone to the final (fully-updated) parameters held by models[1].
local params_final, gradParams_final = models[1]:getParameters()
for i = 2, num do
    local params_cur, gradParams_cur = models[i]:getParameters()
    params_cur:copy(params_final)
    gradParams_cur:copy(gradParams_final)
end

-------- reference: a single module updated sequentially over the same inputs
for i = num, 1, -1 do
    local o = model1:forward(input[i])
    local g = model1:backward(input[i], gradOut[i])
    model1:updateParameters(lr)
    model1:forget()
    model1:zeroGradParameters()
end

-------- check: both strategies should now produce identical outputs
for i = 1, num do
    local out0 = models[i]:forward(input[i])
    local out1 = model1:forward(input[i])
    print(i, out0, out1)
end
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值