Here is how this came up: I built a network model in MXNet and needed to multiply all of its parameters by a number. Take an AlexNet model as an example:
import mxnet as mx
from mxnet.gluon import nn
from mxnet import gluon
dropoutRatio = 0.5
numClass = 10
learningRate = 0.01
net = nn.Sequential()
net.add(
    # strides = 4 is the same as strides = (4, 4)
    # padding = 2 is the same as padding = (2, 2)
    nn.Conv2D(channels = 96, kernel_size = 11, strides = 4, activation = 'relu'), # conv0
    nn.MaxPool2D(pool_size = 3, strides = 2), # pool0
    nn.Conv2D(channels = 256, kernel_size = 5, padding = 2, activation = 'relu'), # conv1
    nn.MaxPool2D(pool_size = 3, strides = 2), # pool1
    nn.Conv2D(channels = 384, kernel_size = 3, padding = 1, activation = 'relu'), # conv2
    nn.Conv2D(channels = 384, kernel_size = 3, padding = 1, activation = 'relu'), # conv3
    nn.Conv2D(channels = 256, kernel_size = 3, padding = 1, activation = 'relu'), # conv4
    nn.MaxPool2D(pool_size = 3, strides = 2), # pool2
    nn.Dense(4096, activation = 'relu'), nn.Dropout(dropoutRatio),
    nn.Dense(4096, activation = 'relu'), nn.Dropout(dropoutRatio),
    nn.Dense(numClass))
net.initialize(force_reinit = True, init = mx.init.Xavier())
loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': learningRate})
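# Assume the input is a single 224x224 grayscale image. Running one forward pass here ends deferred initialization, so the shapes of all network parameters are now determined.
x = mx.nd.ones(shape = (1, 1, 224, 224)) # the two 1s are the batch size and the channel count (a grayscale image has 1 channel)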
y = net(x)
y
[[-0.00624995 -0.02138521 0.01968115 0.02071718 0.00322351 -0.00585919
-0.00409283 0.00211513 0.00078388 -0.02968722]]
<NDArray 1x10 @cpu(0)>
For the next step I needed to multiply every parameter by a number, so I tried the following:
number = 0.5
for param in trainer._params:
    param.list_data() * number
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-4-9d21a053455d> in <module>
1 number = 0.5
2 for param in trainer._params:
----> 3 param.list_data() * number
TypeError: can't multiply sequence by non-int of type 'float'
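The error is not specific to MXNet. As a minimal sketch with a plain Python list (`values` is a made-up example, not MXNet data), this is how `*` behaves on a list:

```python
# Plain-list illustration; `values` is a hypothetical example.
values = [1, 2, 3]

# Multiplying a list by an int repeats the list; it does not scale the elements.
repeated = values * 2
print(repeated)  # [1, 2, 3, 1, 2, 3]

# Multiplying a list by a float is not defined and raises TypeError.
message = None
try:
    values * 0.5
except TypeError as err:
    message = str(err)
    print(message)
```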
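The message says that a sequence cannot be multiplied by a non-int type. That is when I realized that list_data() returns a list, and that multiplying a list by an integer repeats the list that many times. This leads to the question in the title: how do you multiply every element of a list of unknown depth and size by some number? I went with the following: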
def multiply(parameterList, weightRatio):
    # Recurse into lists; scale anything else in place.
    if isinstance(parameterList, list):
        for parameter in parameterList:
            multiply(parameter, weightRatio)
    else:
        parameterList *= weightRatio

for param in trainer._params:
    multiply(param.list_data(), number)
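The code above works recursively: if the argument is a list, each of its elements is scaled in turn; otherwise the current element is scaled directly. This solution is not very elegant, though, and I later found that numpy can do the same thing: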
import numpy as np
l = [[1, 2, 3, 4], [2, 3, 4, 5]]
l = (np.array(l) * 0.5).tolist()
l
[[0.5, 1.0, 1.5, 2.0], [1.0, 1.5, 2.0, 2.5]]
However, it turns out that this approach does not work on the data returned by list_data().
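For nested lists of plain numbers with irregular depth, a recursive function that builds a new list is one option. The in-place `*=` in `multiply` only has an effect when the leaf objects support in-place multiplication (as NDArray objects do). Python numbers are immutable, so the sketch below returns a new list instead. The name `scaled` and the sample data are hypothetical, chosen only for illustration.

```python
def scaled(nested, ratio):
    """Return a new nested list with every number multiplied by ratio."""
    if isinstance(nested, list):
        # Rebuild the list, recursing into each element.
        return [scaled(item, ratio) for item in nested]
    # Leaf value: a plain number.
    return nested * ratio

data = [[1, 2, [3, 4]], [5], 6]
result = scaled(data, 0.5)
print(result)  # [[0.5, 1.0, [1.5, 2.0]], [2.5], 3.0]
print(data)    # the original list is left unchanged
```

Unlike the numpy version, this does not require the sublists to have equal lengths, at the cost of running element by element in pure Python.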