mxnet加载预训练

 

关乎symbol和module的一些基本属性

# 查看json每一个op的属性:kernel size、padding、stride等
sym.attr_dict() # 返回一个字典,根据key获取对应op的属性
# 查看网络的输出name
sym.list_outputs()
# 查看网络所有的输入节点name
sym.list_arguments()
# 查看网络所有内部节点
sym.get_internals()
# 获取网络的参数节点name
mod.get_params()[0]
# 获取网络的中间结果 fc7 output
all_layers = sym.get_internals()
sym = all_layers['fc7_output']
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None) # 然后做一次inference就能获取fc7 output


原文链接:https://blog.csdn.net/wwwhp/article/details/84556909

 

模型:

参数: 

prefix: "mxnet/zwnwet_model", 
epoch:0

加载代码: 

  sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    print(sym)
    # print(arg_params)
    # print(aux_params)
 
    # 提取中间某层输出帖子特征层作为输出
    all_layers = sym.get_internals()
    print(all_layers)
    sym = all_layers['fc1_output']
 
    # 重建模型
    model = mx.mod.Module(symbol=sym, label_names=None)
    model.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))])
    model.set_params(arg_params, aux_params)

加载完毕保存模型


# !/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import time
import math
import mxnet as mx
import cv2
import numpy as np
from collections import namedtuple

sym, arg_params, aux_params = mx.model.load_checkpoint('../new_model', 0)
# print(sym)


# 提取中间某层输出帖子特征层作为输出
all_layers = sym.get_internals()
# print(all_layers)
sym = all_layers['fc1_output']

# 重建模型
model = mx.mod.Module(symbol=sym, label_names=None)
model.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))])
model.set_params(arg_params, aux_params)

model.save_checkpoint("out/aaa",12)

获取指定层的输出
有些时候我们不需要网络的输出,而是只需要网络某个层的输出来通过网络提取图片的特征,这时候我们就需要指定提取层的名称,这里我们通过提取网络最后一层的全连接层为例



def get_specify_mod(model_str,ctx,data_shpae,layer_name):
    _vec = model_str.split(",")
    prefix = _vec[0]
    epoch = int(_vec[1])
    sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)
    #获取神经网络所有的层
    all_layers = sym.get_internals()
    #获取输出层
    sym = all_layers[layer_name+"_output"]
    mod = mx.mod.Module(symbol=sym,context=ctx)
    mod.bind(data_shapes=[("data",data_shpae)])
    mod.set_params(arg_params,aux_params)
    return mod
    
def predict_specify(model_str,ctx,data_shape,img_path,label_path):
    label_names = get_label_names(label_path)
    #通过输出网络层的名称,输出层全连接层的名称为fc1
    mod = get_specify_mod(model_str,ctx,data_shape,layer_name="fc1")
    nd_img = preprocess_img(img_path,data_shape,ctx)
    #将需要预测的图片封装为Batch
    data_batch = mx.io.DataBatch(data=(nd_img,))
    #计算网络的预测值
    mod.forward(data_batch,is_train=False)
    #获取网络的输出值
    output = mod.get_outputs()[0]
    #对输出值进行softmax处理
    proba = mx.nd.softmax(output)
    #获取前top5的值
    top_proba = proba.topk(k=5)[0].asnumpy()
    for index in top_proba:
        probability = proba[0][int(index)].asscalar()*100
        pred_label_name = label_names[int(index)]
        print("label name=%s,probability=%f"%(pred_label_name,probability))

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法网奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值