构建方法
nngraph包在构建更复杂的网络时极其有用,毕竟它有点类似于“静态图”。
简单来说,以前搭建网络需要不断调用add,现在用了nngraph,只要不断用“-”把模块连接起来就行了。
-- Minimal nngraph MLP: 20 -> 10 -> 10 -> 1 with Tanh activations in between.
-- Unary "-" wraps the first module into a graph node (the input node);
-- each binary "-" appends the module on its right after the node on its left.
h1 = -nn.Linear(20, 10)
do
  -- Intermediate node is kept local; only h1, h2 and mlp are exposed,
  -- exactly as in the one-expression form.
  local hidden = h1 - nn.Tanh() - nn.Linear(10, 10)
  h2 = hidden - nn.Tanh() - nn.Linear(10, 1)
end
-- gModule takes a table of input nodes and a table of output nodes.
mlp = nn.gModule({ h1 }, { h2 })
注意点:
1. 开始时需要用一元的“-”来初始化,即把第一个模块包装成图的输入节点。
2. nn.gModule接收两个table,第一个table表示输入节点,第二个table表示输出节点。
当然,这两个table都可以包含多个值。值得注意的是,这两个table中的元素必须是“node”(节点),不能是其他任何类型。
以U-Net结构为例:
function defineG_unet(input_nc, output_nc, ngf)
local netG = nil
-- input is (nc) x 256 x 256
-- 初始化时先用“-”
local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
-- input is (ngf) x 128 x 128
local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
-- input is (ngf * 2) x 64 x 64
local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
-- input is (ngf * 4) x 32 x 32
local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 16 x 16
local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 8 x 8
local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 4 x 4
local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 2 x 2
local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 1 x 1
local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 2 x 2
local d1 = {d1_,e7} - nn.JoinTable(2)
local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 4 x 4
local d2 = {d2_,e6} - nn.JoinTable(2)
local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 8 x 8
local d3 = {d3_,e5} - nn.JoinTable(2)
local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1,