模型固化 - CANN 5.0.2 TensorFlow网络模型移植&训练指南 01 - 华为
# -*- coding: utf-8 -*-
#/usr/bin/python2
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer
Inference
'''
import logging
import os
import tensorflow as tf
# tf2 --> tf1
from tensorflow.python.tools import freeze_graph
tf.compat.v1.disable_v2_behavior()
from model import Transformer
from hparams_dh import Hparams
# Let INFO-level records (such as the final "Done" message) reach the root handler.
logging.basicConfig(level=logging.INFO)
# Hyper-parameters are parsed once at import time from the project's Hparams parser
# (presumably an argparse parser reading the command line -- confirm in hparams_dh).
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
# Checkpoint prefix handed to freeze_graph, resolved against the current working directory.
path_now = os.getcwd()
ckpt_path = path_now + '/model/iwslt2016_E15L4.22-23445'
def main():
    """Build the inference graph and freeze it together with the checkpoint.

    Steps, in order:
      1. reset the default graph and declare the ``inputx`` placeholder;
      2. build the network through ``Transformer(hp).eval`` and add an
         ``argmax`` node named ``output`` on top of the logits;
      3. dump the graph definition to ``./pb_model/model.pb`` and call
         ``freeze_graph`` to merge it with the weights at ``ckpt_path``.

    Reads the module-level ``hp`` and ``ckpt_path``; returns nothing.
    """
    tf1 = tf.compat.v1
    tf1.reset_default_graph()

    # Input node of the inference network.
    inputx = tf1.placeholder(tf1.int32, [None, hp.maxlen1], name="inputx")

    model = Transformer(hp)
    logits, _y_hat, _ = model.eval(inputx)

    # Output node of the inference network. Only the graph node matters here;
    # freeze_graph looks it up by the name "output", so the result is not bound.
    tf1.argmax(logits, axis=-1, name="output", output_type=tf.int32)

    freeze_args = dict(
        input_graph='./pb_model/model.pb',  # graph file produced by write_graph below
        input_saver='',
        input_binary=False,
        input_checkpoint=ckpt_path,  # checkpoint produced by training
        output_node_names='output',  # must match the output node defined above
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        output_graph='./pb_model/tansformer.pb',  # change to the desired inference model name
        clear_devices=False,
        initializer_nodes='',
    )

    with tf1.Session() as sess:
        tf.io.write_graph(sess.graph_def, './pb_model', 'model.pb')
        freeze_graph.freeze_graph(**freeze_args)

    logging.info("Done")
# Run the export only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
pb模型用于预测
import tensorflow as tf
from tensorflow.compat.v1.train import NewCheckpointReader
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
#from create_tf_record import *
from tensorflow.python.framework import graph_util
from hparams_dh import Hparams
from data_utils import get_hypotheses, calc_bleu, postprocess, load_hparams
import logging
import os
from data_preprocess import Dataset
# Restrict CUDA to device index 2 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
hparams = Hparams() # hyper-parameters
parser = hparams.parser
hp = parser.parse_args()
# Shared dataset helper; predict() uses its get_batch() and idx2token.
dataset = Dataset(hp.maxlen1, hp.maxlen2, hp.vocab, hp.batch_size)
def predict(pb_path, keys):
    '''
    Run the frozen graph stored at ``pb_path`` on one batch and decode the result.

    :param pb_path: path of the frozen ``.pb`` file; read in binary mode and
        parsed as a GraphDef.
    :param keys: forwarded unchanged to ``dataset.get_batch`` to obtain the
        input batch (its exact format is defined by ``Dataset`` -- confirm there).
    :return: the value produced by ``get_hypotheses`` for the predicted ids and
        ``dataset.idx2token``.
    '''
    batch, _ = dataset.get_batch(keys)
    print(batch.shape)

    with tf.Graph().as_default():
        # Load the serialized graph and import it without a name prefix so the
        # tensors keep the names they were frozen with.
        graph_def = tf.compat.v1.GraphDef()
        with open(pb_path, "rb") as pb_file:
            graph_def.ParseFromString(pb_file.read())
        tf.import_graph_def(graph_def, name="")

        with tf.compat.v1.Session() as sess:
            graph = sess.graph
            input_tensor = graph.get_tensor_by_name("inputx:0")
            output_tensor = graph.get_tensor_by_name("output:0")
            predicted_ids = sess.run(output_tensor, feed_dict={input_tensor: batch})

            logging.info("# get hypotheses")
            return get_hypotheses(predicted_ids, dataset.idx2token)