Python中将未知维度list的每个元素乘以某一个数

事情的起因是这样的:我在MXNet中创建了一个网络模型,我需要将模型的参数全部乘以一个数字。以一个AlexNet模型为例:

import mxnet as mx
from mxnet.gluon import nn
from mxnet import gluon

dropoutRatio = 0.5
numClass = 10
learningRate = 0.01

net = nn.Sequential()
net.add(
    # strides = 4 means the same with strides = (4, 4)
    # padding = 2 means the same with padding = (2, 2)
    nn.Conv2D(channels = 96, kernel_size = 11, strides = 4, activation = 'relu'), # conv0
    nn.MaxPool2D(pool_size = 3, strides = 2), # pool0
    nn.Conv2D(channels = 256, kernel_size = 5, padding = 2, activation = 'relu'), # conv1
    nn.MaxPool2D(pool_size = 3, strides = 2), # pool1
    nn.Conv2D(channels = 384, kernel_size = 3, padding = 1, activation = 'relu'), # conv2
    nn.Conv2D(channels = 384, kernel_size = 3, padding = 1, activation = 'relu'), # conv3
    nn.Conv2D(channels = 256, kernel_size = 3, padding = 1, activation = 'relu'), # conv4
    nn.MaxPool2D(pool_size = 3, strides = 2), # pool2
    nn.Dense(4096, activation = 'relu'), nn.Dropout(dropoutRatio),
    nn.Dense(4096, activation = 'relu'), nn.Dropout(dropoutRatio),
    nn.Dense(numClass))

net.initialize(force_reinit = True, init = mx.init.Xavier())
loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': learningRate})
# 这里假设输入是一张224*224的黑白图片,这里首先进行一次计算是为了取消延迟初始化,此时整个网络的参数形状已经确定
x = mx.nd.ones(shape = (1, 1, 224, 224)) # 两个1分别代表批量大小和通道数(黑白照片通道数为1)
y = net(x)
y
[[-0.00624995 -0.02138521  0.01968115  0.02071718  0.00322351 -0.00585919
  -0.00409283  0.00211513  0.00078388 -0.02968722]]
<NDArray 1x10 @cpu(0)>

后续的操作我需要将参数全部乘以一个数,于是我做了如下的操作:

number = 0.5
for param in trainer._params:
    param.list_data() * number
---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-4-9d21a053455d> in <module>
      1 number = 0.5
      2 for param in trainer._params:
----> 3     param.list_data() * number


TypeError: can't multiply sequence by non-int of type 'float'

此处提示序列不能乘以一个non-int类型,此时反应过来list_data()获取的是一个listlist乘以一个整数是将list进行复制多份,于是引出了标题的问题,如何将一个未知维度和大小的list中的元素乘以某一个数,于是我采取了下面的操作:

def multiply(parameterList, weightRatio):
    if isinstance(parameterList, list):
        for parameter in parameterList:
            multiply(parameter, weightRatio)
    else:
        parameterList *= weightRatio
for param in trainer._params:
    multiply(param.list_data(), number)

上面代码的原理是通过递归来进行操作,当传入的参数是一个list类型,那么就应该将每一个元素都进行扩大,否则直接将当前元素进行扩大即可。但是上面的解决方法并不优雅,后面发现可以通过numpy来实现该操作:

import numpy as np

l = [[1, 2, 3, 4], [2, 3, 4, 5]]
l = (np.array(l) * 0.5).tolist()
l
[[0.5, 1.0, 1.5, 2.0], [1.0, 1.5, 2.0, 2.5]]

但是发现这个操作并不适用list_data()获取的数据。

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值