由于slim库不是tf的核心库,因此需要到github下载相关代码,这里假设我的工作目录为:/home/hiptonese/MigrationLearning
- 1 下载代码:https://github.com/tensorflow/models
- 2 将下载好的代码放到工作目录下
- 3 下载你所需要的模型的checkpoint文件(该文件存放了模型预训练的变量值),这里列出了各个常用模型的ckpt文件:https://github.com/tensorflow/models/tree/master/research/slim#Pretrained
- 4 加载代码和图片文件,这里给出例子:
'''
@Date : 2017-11-21 19:18
@Author: yangyang Deng
@Email : yangydeng@163.com
'''
import os
import tensorflow as tf
from models.research.slim.datasets import imagenet
from models.research.slim.preprocessing import inception_preprocessing
import numpy as np
# 工程的根目录,同时也是ckpt所在的目录
checkpoints_dir = '/home/hiptonese/MigrationLearning/'
slim = tf.contrib.slim
image_size = 299
with tf.Graph().as_default():
with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
# 加载一张图片
imgPath = 'ship.jpeg'
testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read()
testImage = tf.image.decode_jpeg(testImage_string, channels=3)
processed_image = inception_preprocessing.preprocess_image(testImage, image_size, image_size, is_training=False)
processed_images = tf.expand_dims(processed_image, 0)
# 这里如果我们设置num_classes=None,则可以得到restnet输出的瓶颈层,num_classes默认为10001,是用作imagenet的输出层。同样,我们也可以根据需要修改num_classes为其他的值来满足我们的训练要求。
final_point, endpoints = inception_resnet_v2.inception_resnet_v2(processed_images, num_classes=None, is_training=False)
init_fn = slim.assign_from_checkpoint_fn(os.path.join(checkpoints_dir, 'inception_resnet_v2_2016_08_30.ckpt'),slim.get_model_variables('InceptionResnetV2'))
with tf.Session() as sess:
init_fn(sess)
final_point_eval = np.array(sess.run(final_point))
print(final_point_eval.shape)
× 最后解释一下“瓶颈层”(bottle neck layer)的含义:
瓶颈层一般指网络结束卷基层,将要进入全连层的输入。由于网络中的变量已经做了预训练,因此瓶颈层的输出可以看做是对原始图片的进一步特征提取。因此这里如果将瓶颈层作为输入,后面只需要自己加入FC全连层,则可以不在参数调整和训练上花太多时间,快速达到较好的效果。