MXNet模型参数的延后初始化

模型的延后初始化

先看下面这段代码:

from mxnet import init, nd
from mxnet.gluon import nn

class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        # 实际的初始化逻辑在此省略了
        
net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'),
        nn.Dense(10))
    
net.initialize(init=MyInit())

这里虽然调用了initialize方法对模型net进行初始化,但是实际上初始化过程并没有进行

只有当我们做了一次前向传播模型才进行初始化

X = nd.random.uniform(shape=(2, 20))
Y = net(X)

并且这个初始化过程只在第一次前向计算的时候被调用,并且在第一次前向传播之前我们无法使用data函数和set_data函数来获取和修改参数。

避免延后初始化

1.对已经初始化的模型进行重新初始化的时候,因为参数形状不会发生变化,所以系统能够立即进行重新初始化

net.initialize(init=MyInit(), force_reinit=True)

2.在创建层的时候指定它的输入个数

net = nn.Sequential()
net.add(nn.Dense(256, in_units=20, activation='relu')) # in_units置顶输入个数
net.add(nn.Dense(10, in_units=256))

net.initialize(init=MyInit())

为什么模型要进行延后初始化呢?

答:主要是为了让模型的构造更加简单。例如,我们无须人工推测每个模型的输入个数(尤其层数多的时候,不推测直接写也很麻烦)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值