[torch] create a new layer

0. one important point

When writing updateGradInput, you must store the result in self.gradInput (note the lowercase g) and return that field, not a tensor under some other name: nn.Module.backward ignores updateGradInput's return value and always hands self.gradInput to the previous layer.
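A minimal sketch of the convention (nn.MyLayer here is a made-up identity-like layer, just to illustrate):

require 'nn'

local MyLayer = torch.class('nn.MyLayer', 'nn.Module')

function MyLayer:updateOutput(input)
   self.output = input
   return self.output
end

function MyLayer:updateGradInput(input, gradOutput)
   -- wrong: nn.Module.backward ignores this return value...
   --    self.myGrad = gradOutput:clone(); return self.myGrad
   -- ...and always passes self.gradInput backward, so instead:
   self.gradInput = gradOutput:clone()
   return self.gradInput
end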

1. don't use this version (it keeps all state on the shared parent class table)

Pay attention to the difference between : and . when calling Lua methods: obj:f(x) is syntactic sugar for obj.f(obj, x), so mixing them up silently passes the wrong self.
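The rule in one snippet (t is just an illustrative table):

local t = {}
function t:hello(x) print(self, x) end    -- same as function t.hello(self, x)

t:hello(1)        -- sugar for t.hello(t, 1): self == t, x == 1
t.hello(t, 1)     -- the same call, written out explicitly
t.hello(1)        -- bug: self == 1, x == nil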

MyMean.lua


require 'nn'
require 'os'
local MyMean, parent = torch.class('nn.MyMean', 'nn.Sum')

function MyMean:__init(maskZero)
        --input:  batch x rdim x cdim
        --output: batch x cdim
        --parent.__init(self, 2)        --wrong in this scheme: later calls read the fields off parent
        parent.__init(parent, 2)        --right: stores dimension/sizeAverage on the parent class table
        --parent:__init(2)              --also right: sugar for parent.__init(parent, 2), so don't run both
        --print(self.dimension, self.sizeAverage)       --prints nil: the fields live on parent
        print(parent.dimension, parent.sizeAverage)

        print("finish init!")
        self.maskZero = maskZero or false
end

function MyMean:_getNonZeroDimension(input)
        local dimension = 2
        local rows = nn.Sum(dimension+1)(input)                        --batch x rdim: sum of each row
        local NumNonZero = nn.Sum(dimension)(rows:ne(0):double())      --batch: count of non-zero rows
        local NumNonZero_rep = nn.Replicate(input:size()[3], 2)(NumNonZero)    --batch x cdim

        return NumNonZero_rep
end

function MyMean:updateOutput(input)
        --self.output = parent:updateOutput(input)      --right: sugar for the line below
        self.output = parent.updateOutput(parent, input)        --right
        if self.maskZero then
                self.NumNonZero_rep = self:_getNonZeroDimension(input)  --note the ':', not '.'
                self.output:cdiv(self.NumNonZero_rep)
        else
                self.output:div(input:size()[2])
        end
        return self.output
end

function MyMean:updateGradInput(input, gradOutput)
        self.gradInput = parent:updateGradInput(input, gradOutput)
        if self.maskZero then
                --bug left in this version: gradInput is batch x rdim x cdim but
                --NumNonZero_rep is batch x cdim, so this cdiv fails (version 2 fixes it)
                self.gradInput:cdiv(self.NumNonZero_rep)
        else
                self.gradInput:div(input:size()[2])
        end

        return self.gradInput
end

test.lua

require 'nn'
require 'os'
require 'rnn'
require 'MyMean'

batch = 5
shot = 3
hidden = 2
input = torch.rand(batch, shot, hidden)
model = nn.MyMean(false)
out = model:forward(input)
print(out)
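Since maskZero is false here, the layer reduces to a plain mean over the shot dimension, so a quick sanity check (my addition, not in the original test) is to compare against torch's built-in mean:

ref = input:mean(2):squeeze(2)        -- batch x hidden
print((out - ref):abs():max())        -- expect something near 0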

2. use this version instead

MyMean.lua

require 'nn'
--[[ uncomment to run on a GPU:
require 'cunn'
require 'cutorch'
--]]
require 'os'
local MyMean, parent = torch.class('nn.MyMean', 'nn.Sum')

function MyMean:__init(maskZero)
    --input:  batch x shot x hidden
    --output: batch x hidden
    parent.__init(self)                 --this time state is stored on the instance
    self.sum_model = nn.Sum(2)          --sums over the shot dimension
    self.maskZero = maskZero or false
end

function MyMean:_getNonZeroDimension(input)
    local module1 = nn.Sum(3)                           --output: batch x shot
    local module2 = nn.Sum(2)                           --output: batch
    local module3 = nn.Replicate(input:size()[3], 2)    --output: batch x hidden
    if input:type():find('torch%.Cuda.*Tensor') then
        module1:cuda()
        module2:cuda()
        module3:cuda()
    end
    local rows = module1:forward(input)
    local nz_rows = rows:ne(0):double()                 --1 where the shot row is non-zero
    if input:type():find('torch%.Cuda.*Tensor') then
        nz_rows = nz_rows:cuda()
    end
    local NumNonZero = module2:forward(nz_rows)         --non-zero shots per sample
    local NumNonZero_rep = module3:forward(NumNonZero)

    return NumNonZero, NumNonZero_rep
end

function MyMean:_RepNonZero(input, batch, shot, hidden)
    --expand the per-sample counts (batch) to the full batch x shot x hidden shape
    local reshapes = nn.ParallelTable()
    for i = 1, batch do
        local res = nn.Reshape(shot, hidden)
        reshapes:add(res)
    end
    local m = nn.Sequential()
    local rep = nn.Replicate(shot*hidden)
    m:add(rep)
    m:add(nn.SplitTable(2))
    m:add(reshapes)

    local m2 = nn.Sequential()
    m2:add(nn.JoinTable(1))
    m2:add(nn.Reshape(batch, shot, hidden))

    if input:type():find('torch%.Cuda.*Tensor') then
        m:cuda()
        m2:cuda()
    end
    local out = m:forward(input)
    local out2 = m2:forward(out)

    return out2
end

function MyMean:updateOutput(input)
    self.output = self.sum_model:forward(input)
    if self.maskZero then
        local NumNonZero_rep
        self.NumNonZero, NumNonZero_rep = self:_getNonZeroDimension(input)
        self.output:cdiv(NumNonZero_rep)
    else
        self.output:div(input:size()[2])
    end
    return self.output
end

function MyMean:updateGradInput(input, gradOutput)
    --per point 0: the result must live in self.gradInput
    self.gradInput = self.sum_model:backward(input, gradOutput)
    if self.maskZero then
        local NumNonZero_rep = self:_RepNonZero(self.NumNonZero, (#input)[1], (#input)[2], (#input)[3])
        self.gradInput:cdiv(NumNonZero_rep)
    else
        self.gradInput:div(input:size()[2])
    end

    return self.gradInput
end
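To see what maskZero buys you, here is a small check of my own (not in the original post): zero out one shot and verify that the mean divides by the number of non-zero shots only.

require 'nn'
require 'MyMean'

input = torch.rand(2, 3, 4)
input[1][2]:zero()                       -- sample 1 now has only 2 non-zero shots

masked = nn.MyMean(true):forward(input)
ref = input:sum(2):squeeze(2)
ref[1]:div(2)                            -- divide sample 1 by its 2 non-zero shots ...
ref[2]:div(3)                            -- ... and sample 2 by all 3 shots
print((masked - ref):abs():max())        -- expect something near 0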

test.lua

require 'nn'
require 'rnn'
require 'MyMean'
batch = 5
shot = 3
hidden_size = 4
event_num = 2
dropout_rate = 0.1

--local H2 = nn.Identity()()                    --batch x shot x hidden_size (nngraph version)
local H2 = torch.rand(batch, shot, hidden_size)
--table.insert(inputs, H2)
local H3 = nn.MyMean()(H2)                      --batch x hidden_size
local H_bar = nn.Replicate(event_num, 2)(H3)    --batch x event_num x hidden_size
local reduceDim = nn.ParallelTable()
for i = 1, event_num do
    reduceDim:add(nn.Linear(hidden_size, 1))
end
m = nn.Sequential()
m:add(nn.SplitTable(1, 2))
m:add(reduceDim)
local out = nn.JoinTable(1, 1)(m(H_bar))        --batch x event_num

local out_dropout = nn.Dropout(dropout_rate)(out)
local softscore = nn.LogSoftMax()(out_dropout)

--table.insert(outputs, softscore)
--attenmodule = nn.gModule(inputs, outputs)
print(#softscore)
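The commented-out nn.Identity()/nn.gModule lines suggest this script was carved out of an nngraph module. Putting it back together would look roughly like the sketch below (my reconstruction, reusing the globals defined above; note that once nngraph is loaded, module(node) builds a graph node instead of running forward):

require 'nngraph'

local H2 = nn.Identity()()                      --graph input: batch x shot x hidden_size
local H3 = nn.MyMean()(H2)                      --batch x hidden_size
local H_bar = nn.Replicate(event_num, 2)(H3)    --batch x event_num x hidden_size

local scores = nn.Sequential()
scores:add(nn.SplitTable(1, 2))
local reduceDim = nn.ParallelTable()
for i = 1, event_num do
    reduceDim:add(nn.Linear(hidden_size, 1))
end
scores:add(reduceDim)
scores:add(nn.JoinTable(1, 1))                  --batch x event_num

local softscore = nn.LogSoftMax()(nn.Dropout(dropout_rate)(scores(H_bar)))

attenmodule = nn.gModule({H2}, {softscore})
print(#attenmodule:forward(torch.rand(batch, shot, hidden_size)))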