torch神经网络图

我们可以用神经网络图(NNGraph)创建简单网络。
首先我们需要:

require 'nngraph';

简单网络

我们先创建一个简单的前馈网络。

-- it is common style to mark inputs with identity nodes for clarity.
input = nn.Identity()()

-- each hidden layer is achieved by connecting the previous one
-- here we define a single hidden layer network
h1 = nn.Tanh()(nn.Linear(20, 10)(input))
output = nn.Linear(10, 1)(h1)
mlp = nn.gModule({input}, {output})
x = torch.rand(20)
dx = torch.rand(1)
mlp:updateOutput(x)
mlp:updateGradInput(x, dx)
mlp:accGradParameters(x, dx)
-- draw graph (the forward graph, '.fg')
-- this will produce an SVG in the runtime directory
graph.dot(mlp.fg, 'MLP', 'MLP')
itorch.image('MLP.svg')

1.jpg

节点名称

当我们创建复杂网络的时候,节点名称设置可以化繁为简。

local function get_network()
    -- it is common style to mark inputs with identity nodes for clarity.
    local input = nn.Identity()()
   -- each hidden layer is achieved by connecting the previous one
    -- here we define a single hidden layer network
    local h1 = nn.Linear(20, 10)(input)
    local h2 = nn.Sigmoid()(h1)
    local output = nn.Linear(10, 1)(h2)
     -- the following function call inspects the local variables in this
    -- function and finds the nodes corresponding to local variables.
    nngraph.annotateNodes()
    return nn.gModule({input}, {output}) 
end
mlp = get_network()
x = torch.rand(20)
dx = torch.rand(1)
mlp:updateOutput(x)
mlp:updateGradInput(x, dx)
mlp:accGradParameters(x, dx)
-- draw graph (the forward graph, '.fg')
-- this will produce an SVG in the runtime directory
graph.dot(mlp.fg, 'MLP', 'MLP_Annotated')
itorch.image('MLP_Annotated.svg')

2.jpg

确认运行时错误

-- We need to set debug flag to true
nngraph.setDebug(true)
local function get_network()
    -- it is common style to mark inputs with identity nodes for clarity.
    local input = nn.Identity()()
 -- each hidden layer is achieved by connecting the previous one
    -- here we define a single hidden layer network
    local h1 = nn.Linear(20, 10)(input)
    local h2 = nn.Sigmoid()(h1)
    local output = nn.Linear(10, 1)(h2) 
    -- the following function call inspects the local variables in this
    -- function and finds the nodes corresponding to local variables.
    nngraph.annotateNodes()
    return nn.gModule({input}, {output}) 
end
mlp = get_network()
mlp.name = 'MyMLPError'
x = torch.rand(15) -- note that this input will cause runtime error
-- We do protected call to avoid real error interrupting the notebook
local o, err = pcall(function() mlp:updateOutput(x) end)
itorch.image('MyMLPError.svg')

3.jpg
我们很容易看到h1出了点问题。

更复杂的例子

下面我们来创建RNN的核心模块。

function get_rnn(input_size, rnn_size)
      -- there are n+1 inputs (hiddens on each layer and x)
    local input = nn.Identity()()
    local prev_h = nn.Identity()()
  -- RNN tick
    local i2h = nn.Linear(input_size, rnn_size)(input)
    local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)
    local added_h = nn.CAddTable()({i2h, h2h})
    local next_h = nn.Tanh()(added_h) 
    nngraph.annotateNodes()
    return nn.gModule({input, prev_h}, {next_h})
end
local rnn_net = get_rnn(128, 128)
graph.dot(rnn_net.fg, 'rnn_net', 'rnn_net')
itorch.image('rnn_net.svg')

4.jpg

在时域上连接

下面我们把RNN核心模块在时域上进行连接:

local function get_rnn2(input_size, rnn_size)
    local input1 = nn.Identity()()
    local input2 = nn.Identity()()
    local prev_h = nn.Identity()()
    local rnn_net1 = get_rnn(128, 128)({input1, prev_h})
    local rnn_net2 = get_rnn(128, 128)({input2, rnn_net1})
    nngraph.annotateNodes()
    return nn.gModule({input1, input2, prev_h}, {rnn_net2})
end
local rnn_net2 = get_rnn2(128, 128)
graph.dot(rnn_net2.fg, 'rnn_net2', 'rnn_net2')
itorch.image('rnn_net2.svg')

4.jpg

更多debug方法

即使用不同的命名,网络图也会变得很复杂,我们还可以用标记来标记路径。

local function get_rnn2(input_size, rnn_size)
    local input1 = nn.Identity()():annotate{graphAttributes = {style='filled', fillcolor='blue'}}
    local input2 = nn.Identity()():annotate{graphAttributes = {style='filled', fillcolor='blue'}}
    local prev_h = nn.Identity()():annotate{graphAttributes = {style='filled', fillcolor='blue'}}
    local rnn_net1 = get_rnn(128, 128)({input1, prev_h}):annotate{graphAttributes = {style='filled', fillcolor='yellow'}}
    local rnn_net2 = get_rnn(128, 128)({input2, rnn_net1}):annotate{graphAttributes = {style='filled', fillcolor='green'}}
    nngraph.annotateNodes()
    return nn.gModule({input1, input2, prev_h}, {rnn_net2})
end
local rnn_net3 = get_rnn2(128, 128)
graph.dot(rnn_net3.fg, 'rnn_net3', 'rnn_net3')
itorch.image('rnn_net3.svg')

5.jpg

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值