MXNet学习 (1) :加载预训练模型

  • 首先在MXNet的model zoo下载对应的模型描述文件以及模型参数文件:
    • vgg16:对应vgg16.json vgg16-0000.params
    • resnet50:对应resnet50.json resnet50-0000.params
  • 加载网络结构设置网络运行runtime context:
import mxnet as mx
import numpy as np
from collections import namedtuple

net_name = 'vgg16'
img_name = 'dog.jpg'

# imagenet 图像预处理
def load_image(img_name):
    img = mx.image.imread(img_name)
    img = mx.image.imresize(img, 224, 224)
    img = img.transpose((2,0,1)) # hwc->chw
    img = img.expand_dims(axis=0) # chw->1chw
    img = img.astype('float32')
    return img

# 加载 mxnet symbol
sym, arg, aux = mx.model.load_checkpoint(net_name, 0) #net_name代表加载网络name,第二个参数代表Epoch num
# 设置 runtime context
ctx = mx.cpu() || ctx = mx.gpu()
  • 构造module用于执行symbol得到结果
mod = mx.mod.Module(symbol=sym, context=ctx)
mod.bind(for_training=False, data_shape[('data', (1, 3, 224, 224))]) # 为输入数据分配内存
mod.set_params(arg, aux) # 加载模型参数
Batch = namedtuple('Batch', ['data'])
img = load_image(img_name)
mod.forward(Batch([img])) # 做简单的inference
prob = mod.get_outputs()[0].asnumpy
prob = np.squeeze(prob)
a = np.argsort(prob)[::-1] # 得到分类网络分类置信度的从大到小的结果
  • 关乎symbol和module的一些基本属性
# 查看json每一个op的属性:kernel size、padding、stride等
sym.attr_dict() # 返回一个字典,根据key获取对应op的属性
# 查看网络的输出name
sym.list_outputs()
# 查看网络所有的输入节点name
sym.list_arguments()
# 查看网络所有内部节点
sym.get_internals()
# 获取网络的参数节点name
mod.get_params()[0]
# 获取网络的中间结果 fc7 output
all_layers = sym.get_internals()
sym = all_layers['fc7_output']
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None) # 然后做一次inference就能获取fc7 output
  • 遗留问题mod设置label_names=None时候,会提示一个warning目前不清楚怎么解决
//由于json里面有一个输入节点为softmax_label导致在做inference的时候总是会提示label_shapes的warninig,但是实际上在做inference的时候是不需要输入softmax_label的
Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值