上一篇博文《TensorFlow模型参数的保存和加载》介绍了如何保存和加载TensorFlow模型训练参数,保存对象主要是Tensor/Variables。这一节我们介绍如何保存和复用op。
和Tensor一样,保存op需要在训练时为op指定名字,如下所示:
softmax = tf.nn.softmax(tf.matmul(x, W) + b,name="op_softmax")
在识别阶段,调用get_operation_by_name()函数,以op名字作为参数,如下所示:
op_softmax =sess.graph.get_operation_by_name("op_softmax").outputs[0]
这里需要注意的是,在最后面要加上“.outputs[0]”,否则会出现异常。
如果要直接运行sess.run(op_softmax),需要指定feed_dict。以官方mnist训练案例为例,调用格式为sess.run(op_softmax,feed_dict={x: mnist.test.images})。
在加载复用阶段,W和b的值已经保存在checkpoint数据中,故不需要再次声明W和b。但是,需要通过get_tensor_by_name()获取到x的声明,如下所示:
x = sess.graph.get_tensor_by_name("x:0")
<