深度学习具有强大的特征表达能力。有时候我们训练好分类模型,并不想用来进行分类,而是用来提取特征用于其他任务,比如相似图片计算。接下来讲下如何使用TensorFlow提取特征。
1.必须在模型中命名好要提取的那一层,num_filters_total为提取特征维度,即特征个数,如下
net = _global_avg(net, pool_size=net.get_shape()[1:-1], strides=1)
net = tf.reshape(net, [-1, num_filters_total], name='reshape_feature')
2.通过调用sess.run()来获取reshape_feature层特征
feature = graph.get_operation_by_name("reshape_feature").outputs[0]
batch_predictions, batch_feature = \
sess.run([predictions, feature], {input_x: x_test_batch, dropout_keep_prob: 1.0}