在网上下载的MXNet预训练模型常常是完整的,但是在实际应用中,我们一般只需要网络中某一层作为特征提取,这个时候就需要重建模型,使得网络最后的输出是特征。
加载预训练模型
加载模型使用model.FeedForward.load
就可以了,后面的参数分别是模型的名字、迭代次数和batch大小,需要根据实际模型进行修改:
import mxnet as mx
import numpy as np
model=mx.model.FeedForward.load('model_name',1,num_batch_size=1)
找到特征层
别人训练好的模型我们常常不知道有哪些层,这时候需要列出所有的层,以便于我们找到特征层:
internals=model.symbol.get_internals() #list all symbol
internals.list_outputs()
列出网络中所有的层&