0. One important point: when writing updateGradInput, you must store and
return the result in self.gradInput — not in a variable of any other name.
1. Lua has no `this`; the receiver is called `self`. Pay attention to the
difference between `:` and `.` when defining and calling Lua functions
(`obj:f(x)` is sugar for `obj.f(obj, x)`).
MyMean.lua
require 'nn'
require 'os'
local MyMean, parent = torch.class('nn.MyMean', 'nn.Sum')
function MyMean:__init(maskZero)
  --- Mean over the second dimension.
  -- input:  batch x rdim x cdim
  -- output: batch x cdim
  -- @param maskZero when truthy, divide by the number of non-zero rows
  --        instead of the full row count (optional, defaults to false).
  --
  -- Call the parent constructor on *self*: `parent.__init(self, 2)` is the
  -- correct Torch idiom.  `parent.__init(parent, 2)` or `parent:__init(2)`
  -- would initialize the shared parent class table instead of this
  -- instance, so every instance would silently share that state.
  parent.__init(self, 2)
  print(self.dimension, self.sizeAverage)
  print("finish init!")
  self.maskZero = maskZero or false
end
function MyMean:_getNonZeroDimension(input)
  --- Count per sample how many rows of `input` (batch x rdim x cdim) are
  -- non-zero, replicated to batch x cdim so it can cdiv self.output.
  -- NOTE(review): a row whose entries sum to exactly 0 is counted as zero
  -- even if individual entries are non-zero — confirm this is intended.
  local dimension = 2
  -- Explicit :forward instead of `module(input)`: the call form is hijacked
  -- by nngraph (pulled in via `require 'rnn'`) to build graph nodes.
  local rows = nn.Sum(dimension + 1):forward(input)                 -- batch x rdim
  local NumNonZero = nn.Sum(dimension):forward(rows:ne(0):double()) -- batch
  -- `local` added: the original leaked NumNonZero_rep into the global scope.
  local NumNonZero_rep = nn.Replicate(input:size()[3], 2):forward(NumNonZero)
  return NumNonZero_rep
end
function MyMean:updateOutput(input)
  --- Forward pass: mean of `input` over dimension 2.
  -- Run the parent (nn.Sum) forward on *self* so the result lands in
  -- self.output; `parent.updateOutput(parent, input)` would store it on the
  -- shared class table instead of on this instance.
  self.output = parent.updateOutput(self, input)
  if self.maskZero then
    -- `:` call, not `.` — `self._getNonZeroDimension(input)` would pass
    -- `input` as the receiver and drop the real argument.
    self.NumNonZero_rep = self:_getNonZeroDimension(input)
    self.output:cdiv(self.NumNonZero_rep)
  else
    self.output:div(input:size()[2])
  end
  return self.output
end
function MyMean:updateGradInput(input, gradOutput)
  --- Backward pass: scale the summed gradient by the same divisor used in
  -- forward.  The result must be stored in and returned via self.gradInput.
  -- `parent.updateGradInput(self, ...)`, not `parent:updateGradInput(...)`:
  -- the `:` form passes the class table as the receiver, so the gradient
  -- would be written onto the shared parent table.
  self.gradInput = parent.updateGradInput(self, input, gradOutput)
  if self.maskZero then
    -- NOTE(review): self.NumNonZero_rep is batch x cdim while gradInput is
    -- batch x rdim x cdim, so this cdiv has mismatched element counts; the
    -- second version of this file fixes it via _RepNonZero — confirm.
    self.gradInput:cdiv(self.NumNonZero_rep)
  else
    self.gradInput:div(input:size()[2])
  end
  return self.gradInput
end
test.lua
require 'nn'
require 'os'
require 'rnn'
require 'MyMean'

-- Quick smoke test: forward a single random batch through MyMean with
-- zero-masking disabled and print the batch x hidden result.
batch, shot, hidden = 5, 3, 2
input = torch.rand(batch, shot, hidden)
model = nn.MyMean(false)
out = model:forward(input)
print(out)
2. Improved version
MyMean.lua
require 'nn'
--[[
require 'cunn'
require 'cutorch'
--]]
require 'os'
local MyMean, parent = torch.class('nn.MyMean', 'nn.Sum')
function MyMean:__init(maskZero)
  --- Mean over the second (shot) dimension.
  -- input:  batch x shot x hidden
  -- output: batch x hidden
  -- @param maskZero when truthy, divide by the count of non-zero rows
  --        instead of the full shot count (optional; defaults to false).
  parent.__init(self)
  self.maskZero = maskZero or false
  -- The actual reduction is delegated to an internal nn.Sum over dim 2.
  self.sum_model = nn.Sum(2)
end
function MyMean:_getNonZeroDimension(input)
  --- For a batch x shot x hidden input, count per sample how many of the
  -- `shot` rows are non-zero, and also return that count replicated to
  -- batch x hidden so it can cdiv the summed output.
  -- NOTE(review): a row whose entries sum to exactly 0 is counted as zero
  -- even if individual entries are non-zero — confirm this is intended.
  -- @return NumNonZero      batch          (per-sample non-zero row count)
  -- @return NumNonZero_rep  batch x hidden (count replicated per column)
  -- Hoisted: the CUDA check was evaluated twice in the original.
  local isCuda = input:type():find('torch%.Cuda.*Tensor') ~= nil
  local rowSum = nn.Sum(3)
  local countSum = nn.Sum(2)
  local replicate = nn.Replicate(input:size()[3], 2)
  if isCuda then
    rowSum:cuda()
    countSum:cuda()
    replicate:cuda()
  end
  local rows = rowSum:forward(input)      -- batch x shot
  local nz_rows = rows:ne(0):double()     -- ne() yields a ByteTensor; cast
  if isCuda then
    nz_rows = nz_rows:cuda()
  end
  local NumNonZero = countSum:forward(nz_rows)          -- batch
  -- `local` added: the original leaked NumNonZero_rep as a global.
  local NumNonZero_rep = replicate:forward(NumNonZero)  -- batch x hidden
  return NumNonZero, NumNonZero_rep
end
function MyMean:_RepNonZero(input, batch, shot, hidden)
  --- Expand a per-sample count vector (`input`, length batch) into a
  -- batch x shot x hidden tensor so it can cdiv the gradient tensor.
  -- Builds small throw-away nn pipelines; CUDA-converted when input is CUDA.
  local reshapes = nn.ParallelTable()
  for i = 1, batch do
    reshapes:add(nn.Reshape(shot, hidden))
  end
  -- Stage 1: replicate each scalar shot*hidden times, split per sample,
  -- and reshape each sample's copies to shot x hidden.
  local expand = nn.Sequential()
  expand:add(nn.Replicate(shot * hidden))
  expand:add(nn.SplitTable(2))
  expand:add(reshapes)
  -- Stage 2: join the per-sample blocks back into batch x shot x hidden.
  local rejoin = nn.Sequential()
  rejoin:add(nn.JoinTable(1))
  rejoin:add(nn.Reshape(batch, shot, hidden))
  if input:type():find('torch%.Cuda.*Tensor') then
    expand:cuda()
    rejoin:cuda()
  end
  -- `local` added: the original leaked `out` and `out2` as globals.
  local expanded = expand:forward(input)
  return rejoin:forward(expanded)
end
function MyMean:updateOutput(input)
  --- Forward pass: mean of `input` (batch x shot x hidden) over dim 2.
  self.output = self.sum_model:forward(input)
  if self.maskZero then
    -- Divide each sample by its own non-zero row count; keep the raw count
    -- on self for the backward pass.  `local` added for the replicated
    -- count — the original leaked it as a global.
    local NumNonZero_rep
    self.NumNonZero, NumNonZero_rep = self:_getNonZeroDimension(input)
    self.output:cdiv(NumNonZero_rep)
  else
    self.output:div(input:size()[2])
  end
  return self.output
end
function MyMean:updateGradInput(input, gradOutput)
  --- Backward pass: scale the summed gradient by the same per-sample
  -- divisor used in forward.  The result must be stored in and returned
  -- via self.gradInput.
  self.gradInput = self.sum_model:backward(input, gradOutput)
  if self.maskZero then
    -- `self:` call, not `MyMean:` — the original passed the class table as
    -- the receiver (only harmless because _RepNonZero ignores self).
    -- Also made the replicated count local instead of a global.
    local NumNonZero_rep = self:_RepNonZero(self.NumNonZero,
      (#input)[1], (#input)[2], (#input)[3])
    self.gradInput:cdiv(NumNonZero_rep)
  else
    self.gradInput:div(input:size()[2])
  end
  return self.gradInput
end
test.lua
require 'nn'
require 'rnn'
require 'MyMean'
-- Demo: mean-pool over shots, replicate per event, score each event with a
-- Linear(hidden_size, 1), then log-softmax over events.
-- NOTE(review): `module(tensor)` here relies on nn.Module's __call__ doing a
-- forward; if nngraph overrides __call__, these become graph nodes — verify.
batch = 5
shot = 3
hidden_size = 4
event_num = 2
dropout_rate = 0.1
--local H2 = nn.Identity()() --batch x shot_num x hidden_size
local H2 = torch.rand(batch,shot,hidden_size)
--table.insert(inputs, H2)
local H3 = nn.MyMean()(H2) --batch x hidden_size
local H_bar = nn.Replicate(event_num,2)(H3) --batch x event_num x hidden_size
-- One Linear(hidden_size, 1) per event, applied in parallel to the split
-- event slices.
local reduceDim = nn.ParallelTable()
for i = 1, event_num do
reduceDim:add(nn.Linear(hidden_size, 1))
end
m = nn.Sequential()
m:add(nn.SplitTable(1,2))
m:add(reduceDim)
local out = nn.JoinTable(1,1)(m(H_bar)) --batch x event_num
local out_dropout = nn.Dropout(dropout_rate)(out)
local softscore = nn.LogSoftMax()(out_dropout)
--table.insert(outputs,softscore)
--attenmodule = nn.gModule(inputs, outputs)
-- Prints the size of the final score tensor (expected: batch x event_num).
print(#softscore)