首先在MXNet的model zoo 下载对应的模型描述文件以及模型参数文件:
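- vgg16：对应 vgg16-symbol.json 和 vgg16-0000.params
- resnet50：对应 resnet50-symbol.json 和 resnet50-0000.params

（`mx.model.load_checkpoint(prefix, epoch)` 会读取 `prefix-symbol.json` 和 `prefix-%04d.params`，所以文件必须按这个规则命名。）

然后加载网络结构，并设置网络运行的 runtime context：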
import mxnet as mx
import numpy as np
from collections import namedtuple
# Checkpoint prefix handed to mx.model.load_checkpoint, and the test image path.
net_name, img_name = 'vgg16', 'dog.jpg'
def load_image(img_name, width=224, height=224):
    """Load an image file as a single-sample NCHW float32 batch.

    The image is read with ``mx.image.imread`` (HWC layout), resized to
    ``width`` x ``height``, transposed to CHW and given a leading batch
    axis of size 1, then cast to float32.

    Args:
        img_name: Path of the image file to read.
        width: Target width in pixels. Defaults to 224, matching the
            ``(1, 3, 224, 224)`` data shape the inference module is bound with.
        height: Target height in pixels. Defaults to 224.

    Returns:
        An ``mx.nd.NDArray`` of dtype float32 with a leading batch axis of 1.
    """
    img = mx.image.imread(img_name)  # HWC layout
    # Note: imresize takes the width first, then the height.
    img = mx.image.imresize(img, width, height)
    img = img.transpose((2, 0, 1))  # HWC -> CHW
    img = img.expand_dims(axis=0)  # CHW -> NCHW with batch size 1
    return img.astype('float32')
# Load '<net_name>-symbol.json' and '<net_name>-0000.params' (epoch 0).
sym, arg, aux = mx.model.load_checkpoint(net_name, 0)

# Run on CPU; use mx.gpu() instead to run on a GPU.
ctx = mx.cpu()

# This module is only used for inference, so no label is bound.
# With the default label_names=('softmax_label',) and no label_shapes,
# bind() warns: "Data provided by label_shapes don't match names
# specified by label_names ([] vs. ['softmax_label'])".
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))])
# The full network still has a 'softmax_label' argument. With
# label_names=None it is treated as a parameter that the checkpoint
# does not contain, so it must be allowed to be missing.
mod.set_params(arg, aux, allow_missing=True)

# Minimal batch object for Module.forward: a 'data' field holding a list
# of input NDArrays.
Batch = namedtuple('Batch', ['data'])
img = load_image(img_name)
mod.forward(Batch([img]))
prob = mod.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)
# Class indices ordered by descending probability (a[0] is the top class).
a = np.argsort(prob)[::-1]

# Interactive inspection of the symbol and module. The results are
# discarded when this file runs as a script.
sym.attr_dict()
sym.list_outputs()
sym.list_arguments()
sym.get_internals()
mod.get_params()[0]

# Truncate the network at fc7 to use it as a feature extractor.
# Internal output names depend on the network: 'fc7_output' presumably
# exists only in the VGG graph -- confirm before changing net_name.
all_layers = sym.get_internals()
sym = all_layers['fc7_output']
# fc7 comes before the softmax layer, so the truncated graph should have
# no 'softmax_label' input; label_names=None keeps the names consistent.
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
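遗留问题：运行上面的代码时会提示如下 warning（原先以为是 mod 设置 label_names=None 导致的）：
`Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])`
原因分析：括号中前者是 bind 时提供的 label_shapes 中的名字（这里为空），后者是 Module 的 label_names。`mx.mod.Module` 的 `label_names` 默认是 `('softmax_label',)`，而推理时调用 `bind` 没有传 `label_shapes`，两者不一致，于是触发这个 warning。所以它来自第一个使用默认 `label_names` 的 Module，而不是设置了 `label_names=None` 的那个（此时两边都为空，不会报警）。

解决办法：推理时第一个 Module 也传入 `label_names=None`。此时完整网络中的 `softmax_label` 会被当作参数，因此加载参数时要用 `mod.set_params(arg, aux, allow_missing=True)`。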