'''
使用原网络保存的模型加载到自己重新定义的图上
可以使用python变量名加载模型,也可以使用节点名
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
y = Net.inference_1(x, N_CLASS=5, train=False)
with tf.Session() as sess:
# 程序前面得有 Variable 供 save or restore 才不报错
# 否则会提示没有可保存的变量
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('./model/')
img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
img = sess.run(tf.expand_dims(tf.image.resize_images(
tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))
if ckpt and ckpt.model_checkpoint_path:
print(ckpt.model_checkpoint_path)
saver.restore(sess,'./model/model.ckpt-0')
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
res = sess.run(y, feed_dict={x: img})
print(global_step,sess.run(tf.argmax(res,1)))
2.加载图结构和参数
注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。
简化版本说明:
# 连同图结构一同加载
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
saver.restore(sess,ckpt.model_checkpoint_path)
# 只加载数据,不加载图结构,可以在新图中改变batch_size等的值
# 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('./model/')
saver.restore(sess,ckpt.model_checkpoint_path)
2.TensorFlow二进制模型加载方法:
这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作
# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
# 二进制读取模型文件
with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
# 新建GraphDef文件,用于临时载入模型中的图
graph_def = tf.GraphDef()
# GraphDef加载模型中的图
graph_def.ParseFromString(f.read())
# 在空白图中加载GraphDef中的图
tf.import_graph_def(graph_def,name='')
# 在图中获取张量需要使用graph.get_tensor_by_name加张量名
# 这里的张量可以直接用于session的run方法求值了
# 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names