tensorflow模型封装 ckpt文件转换成pd文件

参考  https://blog.csdn.net/yjl9122/article/details/78341689

https://blog.csdn.net/guyuealian/article/details/82218092

 

 

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
# 本来这个model本无需解释太多,但是这么多人不能耐下心来看,那么我简单的说一下吧
# network是你们自己定义的模型结构而已
# ps:
# def network(input):
#    return tf.layers.max_pooling2d(input, 2, 2)
from model import network
   # 实际使用时要导入自己的模型文件


os.environ['CUDA_VISIBLE_DEVICES']='2'  #设置GPU cpu为默认时可以注释该行


model_path  = "path to /model.ckpt" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前


def main():

    tf.reset_default_graph()

    input_node = tf.placeholder(tf.float32, shape=(228, 304, 3)) #这个是你送入网络的图片大小,如果你是其他的大小自行修改
    input_node = tf.expand_dims(input_node, 0)
    flow_1 = network(input_node)
    flow = tf.cast(flow_1, tf.uint8, '
model_out') #设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用

    saver = tf.train.Saver()
    with tf.Session() as sess:

        saver.restore(sess, model_path)

        #保存图 下面红色为图的保存路径
        tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')
        #把图和参数结构一起
        freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'model_out','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")

    print("done")

 

if __name__ == '__main__':
    main()

 

 

 参数说明:

总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):
1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)
2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。
3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False
4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。
5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。
6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all
7、filename_tensor_name:(可选)已弃用。默认:save/Const:0
8、output_graph:(必选)用来保存整合后的模型输出文件。
9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)
10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。
11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。
 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值