mxnet如何将.json、.params模型转换为gluon模型

Gluon是MXNet的动态图接口;Gluon学习了Keras,Chainer,和Pytorch的优点,并加以改进。接口更简单,且支持动态图(Imperative)编程。相比TF,Caffe2等静态图(Symbolic)框架更加灵活易用。同时Gluon还继承了MXNet速度快,省显存,并行效率高的优点,并支持静、动态图混用,比Pytorch更快。——转自解浚源知乎

题目中提及的.json、.params模型如下所示:

 mxnet版本高于1.2.1可以使用如下方法:

net = gluon.nn.SymbolBlock.imports('resnet18-symbol.json',
                                   ['data'], 
                                   param_file='resnet18-0000.params',
                                   ctx=mx.gpu())

以前的版本可以使用:

sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18', 0)

net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
# Set the params
net_params = net.collect_params()
for param in arg_params:
    if param in net_params:
        net_params[param]._load_init(arg_params[param], ctx=ctx)
for param in aux_params:
    if param in net_params:
        net_params[param]._load_init(aux_params[param], ctx=ctx)

net_params是ParameterDict类型,也就是value为Parameter类型的字典,其可以通过data()函数获得其具体参数,参数类型为NDArray,如:

arraya = net_params['stage4_unit3_bn2_beta'].data()

arg_params,aux_params均是一个字典类型,他们的结构均为"参数名称":NDarray,如:

arrayb = arg_params['stage4_unit3_bn2_beta']

 需要说明的是:

inputs=mx.sym.var('data')

是使用静态图的方法生成一个输入节点名为'data' ,arg_params是主要参数如weights,aux_params是辅助参数主要是bias或者是batchnorm中的一些参数。

疑问:以上方法是对模型中的参数一个个的load,虽然已经装载进去了但是net_params内部的参数的shape扔然是None这是不解的地方,如下所示;

 疑问5.29日解决,原因是:虽然模型函数已经加载了参数,但是mxnet模型推断机制是在模型进行一次前向计算(forward)后才完成,如下图所示:

SymbolBlock是继承于block有好多的Sequence的方法,其并不能使用,如net[0]因为其内部并没有__getitems__()函数所以这种访问模型内部参数字典的方法并不适用

第一种方法的import函数内容如下:

    def imports(symbol_file, input_names, param_file=None, ctx=None):
        """Import model previously saved by `HybridBlock.export` or
        `Module.save_checkpoint` as a SymbolBlock for use in Gluon.

        Parameters
        ----------
        symbol_file : str
            Path to symbol file.
        input_names : list of str
            List of input variable names
        param_file : str, optional
            Path to parameter file.
        ctx : Context, default None
            The context to initialize SymbolBlock on.

        Returns
        -------
        SymbolBlock
            SymbolBlock loaded from symbol and parameter files.

        Examples
        --------
        >>> net1 = gluon.model_zoo.vision.resnet18_v1(
        ...     prefix='resnet', pretrained=True)
        >>> net1.hybridize()
        >>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
        >>> out1 = net1(x)
        >>> net1.export('net1', epoch=1)
        >>>
        >>> net2 = gluon.SymbolBlock.imports(
        ...     'net1-symbol.json', ['data'], 'net1-0001.params')
        >>> out2 = net2(x)
        """
        sym = symbol.load(symbol_file)
        if isinstance(input_names, str):
            input_names = [input_names]
        inputs = [symbol.var(i) for i in input_names]
        ret = SymbolBlock(sym, inputs)
        if param_file is not None:
            ret.collect_params().load(param_file, ctx=ctx)
        return ret

特殊情况下可以用到

def load_model_finetune(model_path, ctx, output_layer=None, expect_prefix=None, is_train=True):
    model_name, epoch = model_path.split(SLASH)[-1].split('-')
    model_path = os.path.join(SLASH, *model_path.split(SLASH)[0:-1], model_name)
    syms, arg_params, aux_params = mx.model.load_checkpoint(model_path, int(epoch))
    if output_layer is not None:
        all_layers = syms.get_internals()
        syms = all_layers[output_layer]
    net = mx.gluon.nn.SymbolBlock(outputs=syms, inputs=mx.sym.var('data'))
    net_params = net.collect_params()
    print(net_params.keys())
    for param in arg_params:
        if param in net_params:
            if expect_prefix is not None:
                if expect_prefix in param:
                    continue
            net_params[param]._load_init(arg_params[param], ctx=ctx)
            if not is_train:
                net_params[param].setattr("grad_req", "null")

    for param in aux_params:
        if param in net_params:
            if expect_prefix is not None:
                if expect_prefix in param:
                    continue
            net_params[param]._load_init(aux_params[param], ctx=ctx)
            if not is_train:
                net_params[param].setattr("grad_req", "null")

    return net

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值