我们知道了如何保存我们的模型接下来,我们就要想办法加载模型,调用模型,这也是我们用来做验证也好、做应用也好必须要做的。当然这里我们只考虑应用和验证,且只涉及模型部分,数据预处理,大家要自己加油啦。下一篇文章将为大家讲解如何加载别人的预训练模型进行微调(fintuning)。
上代码:
import tensorflow as tf
sess = tf.Session()
model_dir = 'xxx/xxx'
#图模型路径
meta_path = os.path.join(model_dir,'xxxx.meta')
#导入图结构,加载.meta文件
saver = tf.train.import_meta_graph(meta_path)
#恢复变量值,加载.ckpt文件
saver.restore(sess,tf.train.latest_checkpoint(model_dir))
#注意,这里我用的是tf.train.latest_checkpoint()函数,其作用是,返回最近一次保存的.ckpt数据文件,大家使用时完全可以手动赋值。
#获得默认图,即加载进来的图
graph = tf.get_default_graph()
#填充feed_dict
x = graph.get_tensor_by_name('input_images:0')#这里的input_images要替换成你的占位符
y = graph.get_tensor_by_name('input_labels:0')#这里的input_labels要替换成你的占位符
feed_dict={x:input_image,y:labels}
#通过图结构的名称来加载某结构的输出
conv1 = graph.get_tensor_by_name('conv1:0')#返回值是tensor
sess.run(conv1,feed_dict) #返回值为ndarry
注意:我们在使用get_tensor_by_name时,参数必须是:'层的名字:0',必须加“:0”。否则会报错:
ValueError: The name 'conv1' refers to an Operation, not a Tensor.
这是因为,conv1代表的是操作的名字,加上:0指代tensorname
完整项目有助于大家加深理解,链接:https://github.com/chenlinzhong/gender-recognition