https://www.cnblogs.com/hellcat/p/9047618.html
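『MXNet』Part 4: Custom Layers in Gluon
1. Layers Without Parameters
By subclassing Block we define CenteredLayer, a custom layer that subtracts the mean from its input. The layer's computation goes in the forward function: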
    from mxnet import nd, gluon
    from mxnet.gluon import nn


    class CenteredLayer(nn.Block):
        def __init__(self, **kwargs):
            super(CenteredLayer, self).__init__(**kwargs)

        def forward(self, x):
            return x - x.mean()


    # Use the layer directly
    layer = CenteredLayer()
    # layer(nd.array([1, 2, 3, 4, 5]))

    # Build a more complex model
    net = nn.Sequential()
    net.add(nn.Dense(128))
    net.add(nn.Dense(10))
    net.add(CenteredLayer())

    # Initialize and run
    net.initialize()
    y = net(nd.random.uniform(shape=(4, 8)))
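To make the computation in forward concrete, here is a minimal plain-Python sketch of the same centering operation. It does not use MXNet; the helper name `center` is hypothetical and only mirrors `x - x.mean()` on a flat list.

```python
def center(values):
    """Subtract the mean of `values` from every element (mirrors x - x.mean())."""
    mean = sum(values) / len(values)
    return [v - mean for v in values]


centered = center([1, 2, 3, 4, 5])
print(centered)       # [-2.0, -1.0, 0.0, 1.0, 2.0]
print(sum(centered))  # 0.0
```

Since the layer holds no parameters, it has nothing of its own to initialize; in the model above only the Dense layers do.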
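2. Layers With Parameters
Note: the custom layer implemented in this section cannot infer its input size automatically, so the size has to be specified by hand.
See the previous post, 『MXNet』Part 3: Gluon Model Parameters. When defining a custom layer, we usually add parameters through the member variable params, an instance of the ParameterDict class that Block provides, as follows: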
    from mxnet import nd, gluon  # nd is needed by forward below
    from mxnet.gluon import nn


    class MyDense(nn.Block):
        def __init__(self, units, in_units, **kwargs):
            super(MyDense, self).__init__(**kwargs)
            # Create the parameters through the block's ParameterDict
            self.weight = self.params.get('weight', shape=(in_units, units))
            self.bias = self.params.get('bias', shape=(units,))

        def forward(self, x):
            linear = nd.dot(x, self.weight.data()) + self.bias.data()
            return nd.relu(linear)


    # Actual usage
    dense = MyDense(5, in_units=10)
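The shapes chosen above determine what forward computes: an input of shape (batch, in_units) multiplied by a weight of shape (in_units, units), plus a bias of shape (units,), followed by ReLU. The sketch below reproduces that arithmetic in plain Python with made-up numbers. `dense_relu` and all values are hypothetical; they illustrate the shapes only, not MXNet's implementation.

```python
def dense_relu(x, weight, bias):
    """relu(x . weight + bias) for x: (batch, in_units), weight: (in_units, units)."""
    in_units, units = len(weight), len(weight[0])
    out = []
    for row in x:
        assert len(row) == in_units, "input width must match in_units"
        linear = [sum(row[i] * weight[i][j] for i in range(in_units)) + bias[j]
                  for j in range(units)]
        out.append([max(0.0, v) for v in linear])
    return out


# batch=1, in_units=3, units=2 (illustrative values)
x = [[1.0, 2.0, 3.0]]
weight = [[1.0, -1.0],
          [0.0, -1.0],
          [1.0, -1.0]]
bias = [0.5, 0.0]
print(dense_relu(x, weight, bias))  # [[4.5, 0.0]]
```

The assertion on the input width is the reason in_units must be known up front here: the weight cannot be allocated without it.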
If you do not want to create the parameter through the ParameterDict, you need to do the following instead:
    # self.weight = self.params.get('weight', shape=(in_units, units))
    self.weight = gluon.Parameter('weight', shape=(in_units, units))
    self.params.update({'weight': self.weight})
Otherwise, net.initialize() will not reach variables that live outside the ParameterDict.
The following example shows this in more detail:
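    class SSD(nn.Block):  # class header assumed; the original excerpt starts at __init__
        def __init__(self, conv_arch, dropout_keep_prob, **kwargs):
            super(SSD, self).__init__(**kwargs)
            # repeat() is a helper defined elsewhere by the author; it is not shown in this post
            self.vgg_conv = nn.Sequential()
            self.vgg_conv.add(repeat(*conv_arch[0], pool=False))
            for i in range(1, len(conv_arch)):
                self.vgg_conv.add(repeat(*conv_arch[i]))

            # An iterator can only be consumed once, so convert it to a tuple.
            # Otherwise it is exhausted while the parameters are being collected,
            # and the loop in forward exits immediately when it iterates again.
            # self.vgg_conv = tuple([repeat(*conv_arch[i])
            #                        for i in range(len(conv_arch))])
            # Parameters are discovered automatically only when the instance attribute
            # is itself an MXNet layer or an MXNet Sequential object. With any other
            # container, the parameters must be collected into the parameter dict.
            # _ = [self.params.update(block.collect_params()) for block in self.vgg_conv]

        def forward(self, x, feat_layers=None):  # feat_layers is unused in this excerpt
            end_points = {'block0': x}
            for index, block in enumerate(self.vgg_conv):
                # Feed the previous block's output into the current block
                end_points['block{:d}'.format(index + 1)] = block(
                    end_points['block{:d}'.format(index)])
            return end_points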
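The commented-out note about iterators describes general Python behaviour rather than anything specific to MXNet. A minimal sketch, with plain strings standing in for the blocks (all names here are illustrative):

```python
names = ['conv1', 'conv2', 'conv3']  # stand-ins for real blocks

lazy = (n.upper() for n in names)    # generator: can be consumed only once
first_pass = list(lazy)
second_pass = list(lazy)             # already exhausted, so this is empty

stored = tuple(n.upper() for n in names)  # tuple: can be iterated repeatedly
print(first_pass)                    # ['CONV1', 'CONV2', 'CONV3']
print(second_pass)                   # []
print(list(stored) == list(stored))  # True
```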
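The parameters inside a layer are discovered by default only when the attribute object is an MXNet object; otherwise they must be collected into self.params explicitly.
Test code: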
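This rule can be illustrated with a toy model of attribute-based registration. It is a deliberately simplified sketch, not Gluon's actual implementation: `ToyBlock`, `ToyDense`, `Direct`, `InTuple`, and their methods are hypothetical names. It only shows why a child stored directly as an attribute is found, while children hidden inside a plain tuple are missed until their parameters are added explicitly.

```python
class ToyBlock:
    """Hypothetical stand-in for a block; NOT the real Gluon Block."""

    def __init__(self):
        # Bypass our own __setattr__ while creating the bookkeeping dicts.
        object.__setattr__(self, 'children', {})
        object.__setattr__(self, 'params', {})

    def __setattr__(self, name, value):
        # Only attributes that are themselves ToyBlock instances get registered.
        if isinstance(value, ToyBlock):
            self.children[name] = value
        object.__setattr__(self, name, value)

    def collect_params(self):
        """Own params plus those of registered children, keyed by dotted path."""
        found = dict(self.params)
        for child_name, child in self.children.items():
            for key, val in child.collect_params().items():
                found[child_name + '.' + key] = val
        return found


class ToyDense(ToyBlock):
    def __init__(self):
        super().__init__()
        self.params['weight'] = 'uninitialized'


class Direct(ToyBlock):
    def __init__(self):
        super().__init__()
        self.layer = ToyDense()  # registered automatically


class InTuple(ToyBlock):
    def __init__(self, collect=False):
        super().__init__()
        self.layers = (ToyDense(), ToyDense())  # a tuple is not a ToyBlock
        if collect:
            # Explicitly copy the hidden children's params into our own dict.
            for i, block in enumerate(self.layers):
                for key, val in block.collect_params().items():
                    self.params['layers{}.{}'.format(i, key)] = val


print(sorted(Direct().collect_params()))               # ['layer.weight']
print(sorted(InTuple().collect_params()))              # []
print(sorted(InTuple(collect=True).collect_params()))  # ['layers0.weight', 'layers1.weight']
```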
    import pprint as pp

    import mxnet as mx

    if __name__ == '__main__':
        ssd = SSD(conv_arch=((2, 64), (2, 128), (3, 256), (3, 512), (3, 512)),
                  dropout_keep_prob=0.5)
        ssd.initialize()
        X = mx.ndarray.random.uniform(shape=(1, 1, 304, 304))
        # Print the output shape of every recorded block
        pp.pprint([x[1].shape for x in ssd(X).items()])
You can verify this yourself.