总说
先附加几个有用的。
replace(function)
这个可以是将function应用到net中每一层上。比如,我们可以将model中的nn.Dropout
替换成nn.Identity()
,显然传入的参数module
随便写成什么变量,都指某一层。
model:replace(function(module)
if torch.typename(module) == 'nn.Dropout' then
return nn.Identity()
else
return module
end
end)
apply(function)
这个和上面的类似,也是对每一层进行操作。
local function weights_init(m)
local name = torch.type(m)
if name:find('Convolution') then
m.weight:normal(0.0, 0.02)
m.bias:fill(0)
elseif name:find('BatchNormalization') then
if m.weight then m.weight:normal(1.0, 0.02) end
if m.bias then m.bias:fill(0) end
end
end
-- define net
...
net:apply(weights_init) --这样就把net的每一层自定义初始化参数了
remove和insert
有时候我们想直接移除某一层,或是中间添加一层。
model = nn.Sequential()
model:add(nn.Linear(10, 20))
model:add(nn.Linear(20, 20))
model:add(nn.Linear(20, 30))
-- 直接写移除的层的index即可
model:remove(2)
> model
nn.Sequential {
[input -> (1) -> (2) -> output]
(1): nn.Linear(10 -> 20)
(2): nn.Linear(20 -> 30)
}
对于insert,
model = nn.Sequential()
model:add(nn.Linear(10, 20))
model:add(nn.Linear(20, 30))
-- 希望插入的Linear(20,20)在model中的第二层
model:insert(nn.Linear(20, 20), 2