Tensorflow Models

Tensorflow 模型读取保存等常用操作

1. 可视化

可视化网络结构图

  • 只需要在sess中加载好graph

  • writer = tf.summary.FileWriter('/home/wurui/dl_rpository/facenet_wz/LOG',sess.graph)
  • sess.run 后就能在LOG文件夹生成log

  • 在命令行中输入tensorboard --logdir '/home/wurui/dl_rpository/facenet_wz/LOG' 得到tensorboard端口,打开链接即可

2. Model 读取保存

1. frozen_graph.pb 读取

# 直接读取pb只能获得 graph_def 类型的图
def load_graph(filename):
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(filename, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return graph_def

# 将 graph_def 赋给当前默认图,得到 graph 类型图
with session as sess:
    load model
    tf.import_graph_def(graph_def, name='')

# sess.graph > sess.graph_def 

2. frozen_graph.pb 保存

# 为了冻结所有参数,需要指定最终的节点
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,                             output_node_names=['embeddings'])
with tf.gfile.FastGFile(new_path_pb, mode='wb') as f:
    f.write(output_graph_def.SerializeToString())

3. graph.pb的保存

tf.train.write_graph(sess.graph, '/home/wurui/dl_rpository/Realtime_Multi-Person_Pose_Estimation-master/tf-openpose-master/models/', 'graph.pb')
# 只保存网络结构,等同于 .meta 文件

4. ckpt文件的读取

# 读取 .meta + .ckpt 或者 .meta + .ckpt.index + .ckpt.data 文件
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('/home/wurui/dl_rpository/Realtime_Multi-Person_Pose_Estimation-master/tf-openpose-master/models/ckpt/mobilenet.ckpt.meta')
    new_saver.restore(sess, "models/ckpt/mobilenet.ckpt")

5. ckpt 文件的保存

save_path = saver.save(sess, "models/ckpt/mobilenet.ckpt")
# 会得到 .checkpoint + .meta + .ckpt.index + .ckpt.data 四个文件

3. 模型操作

1. sess.run

with tf.Session(config=config) as sess:
    switch = sess.run([    tf.get_default_graph().get_tensor_by_name("InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/switch_t:0"),        tf.get_default_graph().get_tensor_by_name("InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/switch_t:0")], feed_dict=feed_dict)

2. 输出op/node名字

# load graph
saver = tf.train.import_meta_graph("/home/wurui/dl_rpository/Realtime_Multi-Person_Pose_Estimation-master/tf-openpose-master/models/mobilenet_0.75_0.50_model-388003/model-388003.meta")
with tf.Session() as sess:
     saver.restore(sess, "/home/wurui/dl_rpository/Realtime_Multi-Person_Pose_Estimation-master/tf-openpose-master/models/mobilenet_0.75_0.50_model-388003/model-388003")
     # 输出op.names
     graph = tf.get_default_graph()
     for op in graph.get_operations():
         print op.name
         f.write(op.name+'\n')

3. 替换图中节点

def load_graph(filename):
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(filename, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return graph_def

# load graph
graph_def = tf.Graph(old_graph)

# node.name 是编码型unicode不是str
str_a = 'phase_train'
target_node_name = unicode(str_a, "ascii")
c = tf.constant(False, dtype=bool, shape=[], name=target_node_name)
# c = tf.constant(False, dtype=bool, shape=[], name='phase_train')

# 创建new grpah,从old graph获取值 
new_graph_def = graph_pb2.GraphDef()
for node in graph_def.node:
    if node.name == target_node_name:
        new_graph_def.node.extend([c.op.node_def])
    else:
        new_graph_def.node.extend([copy.deepcopy(node)])
# 主要删除训练中的节点 numeric nodes, 暂时没用
# cut_graph = graph_util_impl.remove_training_nodes(graph_def)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值