tensorflow ckpt 模型固化为pb 注意事项

Wiki - Gitee.com

模型固化 - CANN 5.0.2 TensorFlow网络模型移植&训练指南 01 - 华为

# -*- coding: utf-8 -*-
#/usr/bin/python2
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer

Inference
'''
import logging
import os

import tensorflow as tf
# tf2 --> tf1
from tensorflow.python.tools import freeze_graph

tf.compat.v1.disable_v2_behavior()

from model import Transformer
from hparams_dh import Hparams

logging.basicConfig(level=logging.INFO)
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()

path_now = os.getcwd()
ckpt_path = path_now + '/model/iwslt2016_E15L4.22-23445'
def main():
    tf.compat.v1.reset_default_graph()
    inputx = tf.compat.v1.placeholder(tf.compat.v1.int32, [None, hp.maxlen1], name="inputx")
    tranformer = Transformer(hp)
    logits, y_hat, _ = tranformer.eval(inputx)
    predict_class = tf.compat.v1.argmax(logits, axis=-1, name="output", output_type=tf.int32)
    with tf.compat.v1.Session() as sess:
        tf.io.write_graph(sess.graph_def, './pb_model', 'model.pb')
        freeze_graph.freeze_graph(
            input_graph='./pb_model/model.pb',  # 传入write_graph生成的模型文件
            input_saver='',
            input_binary=False,
            input_checkpoint=ckpt_path,  # 传入训练生成的checkpoint文件
            output_node_names='output',  # 与定义的推理网络输出节点保持一致
            restore_op_name='save/restore_all',
            filename_tensor_name='save/Const:0',
            output_graph='./pb_model/tansformer.pb',  # 改为需要生成的推理网络的名称
            clear_devices=False,
            initializer_nodes=''
        )
    logging.info("Done")

if __name__== "__main__":
    main()


pb模型用于预测

import tensorflow as tf
from tensorflow.compat.v1.train import NewCheckpointReader
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
#from create_tf_record import *
from tensorflow.python.framework import graph_util
from hparams_dh import Hparams
from data_utils import get_hypotheses, calc_bleu, postprocess, load_hparams
import logging
import os
from data_preprocess import Dataset
os.environ["CUDA_VISIBLE_DEVICES"] = "2"



hparams = Hparams()   # 参数
parser = hparams.parser
hp = parser.parse_args()
dataset = Dataset(hp.maxlen1, hp.maxlen2, hp.vocab, hp.batch_size)

def predict(pb_path, keys):
    '''
    :param pb_path:pb文件的路径
    :param image_path:测试图片的路径
    :return:
    '''
    datas, _ = dataset.get_batch(keys)
    # print(datas)
    print(datas.shape)


    with tf.Graph().as_default():
        output_graph_def = tf.compat.v1.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        with tf.compat.v1.Session() as sess:

            inputx = sess.graph.get_tensor_by_name("inputx:0")

            output_tensor_name = sess.graph.get_tensor_by_name("output:0")
            
            predict = sess.run(output_tensor_name, feed_dict={inputx: datas})
            # print(predict)
            logging.info("# get hypotheses")
            res = get_hypotheses(predict, dataset.idx2token)

        return res

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值