将tensorflow的ckpt模型转为pb文件, 需要知道网络的输出节点名称, 如果不指定输出节点名称, 程序就不知道该freeze哪些节点, 就没有办法保存模型.
获取ckpt模型中的节点名称
from tensorflow.python import pywrap_tensorflow
checkpoint_path = 'model.ckpt-xxx'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
将模型转为pb模型
import tensorflow as tf
def model(input):
net = tf.layers.conv2d(input,filters=32, kernel_size=3)
net = tf.layers.batch_normalization(net, fused=False)
net = tf.layers.separable_conv2d(net, 32, 3)
net = tf.layers.conv2d(net, filters=32, kernel_size=3, name='output')
return net
input_node = tf.placeholder(tf.float32, [1, 480, 480, 3], name='image')
pb = 'tftest.pb'
with tf.Session(