Tensorflow 模型读取保存等常用操作
1. 可视化
可视化网络结构图
只需在 sess 中加载好 graph,然后创建一个 FileWriter:
# Write the session's graph to the LOG directory so TensorBoard can display it.
writer = tf.summary.FileWriter('/home/wurui/dl_rpository/facenet_wz/LOG', sess.graph)
# FileWriter updates the event file asynchronously; flush so the graph event is
# on disk even if the script exits without calling writer.close().
writer.flush()
执行 sess.run 后,就会在 LOG 文件夹中生成日志文件。然后在命令行中输入:
tensorboard --logdir '/home/wurui/dl_rpository/facenet_wz/LOG'
命令行会输出 TensorBoard 的访问地址和端口(默认端口为 6006),在浏览器中打开该链接即可。
2. Model 读取保存
1. frozen_graph.pb 读取
# Reading a .pb file directly only yields a GraphDef (a protobuf message),
# not a tf.Graph object.
def load_graph(filename):
    """Parse a binary GraphDef protobuf file (e.g. a frozen graph).

    Args:
        filename: path to the binary ``.pb`` file.

    Returns:
        A ``tf.GraphDef`` populated from the file's contents.
    """
    with tf.gfile.FastGFile(filename, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    return graph_def
# Import the GraphDef into the current default graph to obtain a tf.Graph.
# NOTE(review): `graph_def` is assumed to be the result of load_graph(...).
with tf.Session() as sess:
    # load model; name='' keeps the original node names (the default would
    # prefix every node with "import/").
    tf.import_graph_def(graph_def, name='')
    # sess.graph is the tf.Graph object; sess.graph_def is its GraphDef proto.
2. frozen_graph.pb 保存
# Freezing requires naming the output node(s): the subgraph needed to compute
# them is kept and its variables are replaced by constants holding the current
# values from `sess`.
output_graph_def = tf.graph_util.convert_variables_to_constants(
    sess,
    sess.graph_def,
    output_node_names=['embeddings'],
)
serialized = output_graph_def.SerializeToString()
with tf.gfile.FastGFile(new_path_pb, mode='wb') as f:
    f.write(serialized)
3. graph.pb的保存
# Save only the network structure: a GraphDef with no variable values.
# Note: this is not identical to a .meta file, which stores a MetaGraphDef that
# additionally contains the Saver definition and collections.
# as_text=False writes the binary protobuf, so the ".pb" file can be read back
# with GraphDef.ParseFromString (as load_graph does). The default as_text=True
# would write a text-format proto despite the .pb name.
tf.train.write_graph(sess.graph, '/home/wurui/dl_rpository/Realtime_Multi-Person_Pose_Estimation-master/tf-openpose-master/models/', 'graph.pb', as_text=False)
4. ckpt文件的读取
# Restore from a .meta file plus checkpoint data: either .ckpt (V1 format) or
# .ckpt.index + .ckpt.data-* (V2 format).
# Use one prefix for both files so the snippet does not depend on the current
# working directory (the .meta path was absolute, the restore path relative).
ckpt_prefix = '/home/wurui/dl_rpository/Realtime_Multi-Person_Pose_Estimation-master/tf-openpose-master/models/ckpt/mobilenet.ckpt'
with tf.Session() as sess:
    # Rebuild the graph structure into the default graph from the MetaGraphDef.
    new_saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
    # Load the variable values; restore() takes the checkpoint prefix, not a file.
    new_saver.restore(sess, ckpt_prefix)
5. ckpt 文件的保存
# NOTE(review): `saver` is assumed to be a tf.train.Saver built on this graph.
save_path = saver.save(sess=sess, save_path="models/ckpt/mobilenet.ckpt")
# With the V2 format this produces a `checkpoint` state file in models/ckpt/
# plus mobilenet.ckpt.meta, mobilenet.ckpt.index and mobilenet.ckpt.data-*
# (typically mobilenet.ckpt.data-00000-of-00001).
3. 模型操作
1. sess.run
with tf.Session(config=config) as sess:
    # Inside this block the session's graph is the default graph.
    graph = tf.get_default_graph()
    switch_name = "InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/switch_t:0"
    switch_tensor = graph.get_tensor_by_name(switch_name)
    # NOTE(review): both fetches are the same switch_t tensor — possibly a
    # copy-paste slip; confirm which tensors were meant to be inspected.
    switch = sess.run([switch_tensor, switch_tensor], feed_dict=feed_dict)
2. 输出op/node名字
# load graph: rebuild the structure from the .meta file.
saver = tf.train.import_meta_graph("/home/wurui/dl_rpository/Realtime_Multi-Person_Pose_Estimation-master/tf-openpose-master/models/mobilenet_0.75_0.50_model-388003/model-388003.meta")
with tf.Session() as sess:
    # Load the variable values from the checkpoint prefix.
    saver.restore(sess, "/home/wurui/dl_rpository/Realtime_Multi-Person_Pose_Estimation-master/tf-openpose-master/models/mobilenet_0.75_0.50_model-388003/model-388003")
    # Print every op name in the graph and also dump them to a text file.
    graph = tf.get_default_graph()
    # NOTE(review): output file name is a placeholder — adjust as needed.
    with open('op_names.txt', 'w') as f:
        for op in graph.get_operations():
            print(op.name)
            f.write(op.name + '\n')
3. 替换图中节点
def load_graph(filename):
    """Parse a binary GraphDef protobuf file (e.g. a frozen graph).

    Args:
        filename: path to the binary ``.pb`` file.

    Returns:
        A ``tf.GraphDef`` populated from the file's contents.
    """
    with tf.gfile.FastGFile(filename, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    return graph_def
# load graph: parse the GraphDef from disk (tf.Graph() takes no such argument).
# NOTE(review): `old_graph` is assumed to be the path to the .pb file — confirm.
graph_def = load_graph(old_graph)

# Name of the node to replace. Protobuf node names compare equal to a plain
# ASCII str in both Python 2 and 3, so no unicode() conversion is needed.
str_a = 'phase_train'
target_node_name = str_a

# Replacement node: a scalar boolean constant False.
c = tf.constant(False, dtype=tf.bool, shape=[], name=target_node_name)
replacement_node = copy.deepcopy(c.op.node_def)
# tf.constant uniquifies names in the default graph (e.g. 'phase_train_1' if
# 'phase_train' already exists there); force the exact name so nodes in the
# loaded graph that take target_node_name as input still resolve.
replacement_node.name = target_node_name

# Build a new GraphDef: copy every node from the old graph, swapping in the
# replacement for the target node.
new_graph_def = graph_pb2.GraphDef()
for node in graph_def.node:
    if node.name == target_node_name:
        new_graph_def.node.extend([replacement_node])
    else:
        new_graph_def.node.extend([copy.deepcopy(node)])

# Mainly strips training-only nodes (e.g. numeric-check nodes); not used for now:
# cut_graph = graph_util_impl.remove_training_nodes(graph_def)