在使用Tensorflow做读取并finetune的时候,发现在读取官方给的inception_v3预训练模型总是出现各种错误,现记录其正确的读取方式和各种错误做法:
关键代码如下:
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import inception_v3
.....................................................
# 读取网络
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)
....................................................
with tf.Session() as sess:
# 先初始化所有变量,避免有些变量未读取而产生错误
init = tf.global_variables_initializer()
sess.run(init)
#加载预训练模型
print('Loading model check point from {:s}'.format(Pretrained_model_dir))
#这里的exclusions是不需要读取预训练模型中的Logits,因为默认的类别数目是1000,当你的类别数目不是1000的时候,如果还要读取的话,就会报错
exclusions = ['InceptionV3/Logits',
'InceptionV3/AuxLogits']
#创建一个列表,包含除了exclusions之外所有需要读取的变量
inception_exc