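This post shows how to save a trained model as a checkpoint (ckpt), convert the checkpoint into a frozen PB file, and then load that PB file for inference. The code for generating and loading the PB file follows a fixed pattern, so it is worth bookmarking. The graph being trained comes from my previous post on predicting Titanic survival, see: https://blog.csdn.net/aspirinLi/article/details/105350452. Because the inputs and outputs of this graph are fairly simple, the model-testing part is simplified as well; image inputs still need to be tried when I get the chance.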
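I. First, save the trained model as a checkpoint
1. Before creating the Session(), add:
saver = tf.train.Saver()
2. Inside the Session() block, at the point where the model should be saved, add:
saver.save(sess,'./model/model.ckpt',global_step=i+1)
The relevant part of the Titanic prediction code looks like this: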
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
    loss_train = []
    train_acc = []
    test_acc = []
    sess.run(init)
    for i in range(10000):
        # one pass over the training set in mini-batches of 100 rows
        for n in range(len(data_target)//100+1):
            batch_xs = data_train[n*100:n*100+100]
            batch_ys = data_target[n*100:n*100+100]
            sess.run(step,feed_dict={x:batch_xs,y:batch_ys})
        if i%1000 == 0:
            # log loss and accuracy on the last batch and on the test set
            loss_temp = sess.run(loss,feed_dict={x:batch_xs,y:batch_ys})
            loss_train.append(loss_temp)
            train_acc_temp = sess.run(accuracy,feed_dict={x:batch_xs,y:batch_ys})
            train_acc.append(train_acc_temp)
            test_acc_temp = sess.run(accuracy,feed_dict={x:data_test,y:test_label})
            test_acc.append(test_acc_temp)
            print(loss_temp,train_acc_temp,test_acc_temp)
            # save a checkpoint; global_step is appended to the file name
            saver.save(sess,'./model/model.ckpt',global_step=i+1)
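The inner loop above walks the training set in slices of 100 rows. As a framework-free sketch of the same index arithmetic (the helper name `batch_bounds` and the row counts are made up for illustration), the snippet below shows which rows each batch covers. It also shows that the `+1` in the range produces an empty final batch whenever the row count is an exact multiple of 100, which may be worth guarding against.

```python
def batch_bounds(num_rows, batch_size=100):
    """Return the (start, stop) row ranges visited by the training loop.

    Mirrors `range(len(data_target)//100+1)`; `stop` is clipped to the
    number of rows, just as Python slicing clips it.
    `batch_bounds` is a hypothetical helper, not part of the original code.
    """
    bounds = []
    for n in range(num_rows // batch_size + 1):
        start = n * batch_size
        stop = min(start + batch_size, num_rows)
        bounds.append((start, stop))
    return bounds


print(batch_bounds(250))  # the last batch is partial
print(batch_bounds(200))  # the last batch is empty
```

If the empty batch turns out to be a problem, one option is to skip any batch where `start == stop`.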
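The full Titanic code is at: https://blog.csdn.net/aspirinLi/article/details/105350452
This way the model is saved while training runs. When training finishes, the ./model directory contains the checkpoint files: a `checkpoint` file plus `.meta`, `.index` and `.data-*` files for each saved step.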
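Since `global_step=i+1` is appended to the file name, it helps to know which prefixes to expect. The sketch below lists them, assuming the save call sits inside the `if i%1000 == 0` branch, as the `model.ckpt-9001` path used in section II suggests. The helper name `saved_prefixes` is made up for illustration.

```python
def saved_prefixes(num_iterations=10000, every=1000, base="./model/model.ckpt"):
    """List the checkpoint prefixes written by saver.save(..., global_step=i+1).

    Assumes a save happens whenever i % every == 0 (hypothetical helper).
    """
    return [f"{base}-{i + 1}" for i in range(num_iterations) if i % every == 0]


prefixes = saved_prefixes()
print(prefixes[0])    # first checkpoint prefix
print(prefixes[-1])   # last checkpoint prefix, the one frozen in section II
```

Note that `tf.train.Saver` keeps only the five most recent checkpoints by default (`max_to_keep=5`), so the earlier prefixes will usually have been deleted by the time training ends.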
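II. Generate the PB file
A tip first: give the tensors you want to keep in the frozen graph a name when you build the training graph. That makes them much easier to look up later. For example:
x = tf.placeholder(shape=[None,12],dtype = tf.float32,name='input_x')
output = tf.add(tf.matmul(x,weight),bias,name = 'output')
Next, generate the PB file: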
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np

def freeze_graph(input_checkpoint, output_graph):
    # names of the nodes to keep in the frozen graph
    output_node_names = "input_x,output"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the variable values from the checkpoint
        output_graph_def = graph_util.convert_variables_to_constants(  # freeze: turn variables into constants
            sess=sess,
            input_graph_def=sess.graph_def,  # the graph loaded into this session
            output_node_names=output_node_names.split(","))  # multiple output nodes are separated by commas
        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialized output
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of op nodes in the frozen graph

# path of the input ckpt model
input_checkpoint='D:/TensorflowTest/taitanic/model/model.ckpt-9001'
# path of the output pb model
out_pb_path="D:/TensorflowTest/taitanic/model/frozen_model.pb"
# call freeze_graph to convert the ckpt into a pb
freeze_graph(input_checkpoint,out_pb_path)
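One detail that is easy to trip over: `output_node_names` takes operation names (`input_x`, `output`), whereas `get_tensor_by_name` in the next section takes tensor names of the form `op_name:output_index` (`input_x:0`, `output:0`). The following plain-Python sketch only illustrates that naming convention. `split_tensor_name` is a hypothetical helper, not a TensorFlow function.

```python
def split_tensor_name(name):
    """Split 'op_name:index' into (op_name, index).

    A bare op name without ':' is treated as output 0 of that op.
    Hypothetical helper for illustration only.
    """
    op_name, sep, index = name.rpartition(":")
    if not sep:
        return name, 0
    return op_name, int(index)


for tensor_name in ["input_x:0", "output:0", "output"]:
    print(tensor_name, "->", split_tensor_name(tensor_name))
```

So the names passed to the freezing step are simply the tensor names used at load time with the `:0` suffix removed.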
III. Load the PB file
The test file used here is test.csv. The preprocessing is the same as in the Titanic survival prediction post.
import tensorflow as tf
import pandas as pd
import os
import numpy as np
from tensorflow.python.platform import gfile

# first check that the pb file exists
savePbFile = './model/frozen_model.pb'
if not os.path.exists(savePbFile):
    print('Not found pb file!')
    exit()

# same preprocessing as for the training data
data_test = pd.read_csv('./data/test.csv')
data_test = data_test[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Cabin', 'Embarked']]
data_test['Age'] = data_test['Age'].fillna(data_test['Age'].mean())
data_test['Cabin'] = pd.factorize(data_test.Cabin)[0]
data_test.fillna(0, inplace=True)
data_test['Sex'] = [1 if x == 'male' else 0 for x in data_test.Sex]
# one-hot encode Pclass
data_test['p1'] = np.array(data_test['Pclass'] == 1).astype(np.int32)
data_test['p2'] = np.array(data_test['Pclass'] == 2).astype(np.int32)
data_test['p3'] = np.array(data_test['Pclass'] == 3).astype(np.int32)
del data_test['Pclass']
# one-hot encode Embarked
data_test['e1'] = np.array(data_test['Embarked'] == 'S').astype(np.int32)
data_test['e2'] = np.array(data_test['Embarked'] == 'C').astype(np.int32)
data_test['e3'] = np.array(data_test['Embarked'] == 'Q').astype(np.int32)
del data_test['Embarked']
#print(data_test)

# ground-truth labels for the 418 test rows, as a column vector
gender = pd.read_csv('./data/gender.csv')
gender = np.reshape(gender.Survived.values.astype(np.float32),(418,1))
with tf.Session() as sess:
    # open the pb model file
    with gfile.FastGFile(savePbFile, 'rb') as fd:
        # parse the serialized graph and import it into the session's graph
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fd.read())
        tf.import_graph_def(graph_def, name='')
    # look up the input and output tensors by name
    input_x = sess.graph.get_tensor_by_name('input_x:0')
    output = sess.graph.get_tensor_by_name('output:0')
    # turn the logits into 0/1 predictions and compare them with the labels
    output_e = tf.cast(tf.sigmoid(output)>0.5,tf.float32)
    acc = tf.reduce_mean(tf.cast(tf.equal(gender, output_e),dtype = tf.float32))
    result = sess.run(output_e, feed_dict={input_x:data_test})
    accuracy = sess.run(acc,feed_dict={input_x:data_test})
    print(result)
    print(accuracy)
And that's it.
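The last few graph ops above turn the raw `output` values (logits) into 0/1 predictions and compare them with the labels. As a cross-check of that arithmetic, here is the same computation in plain Python on made-up logits and labels. The function names are illustrative only. Since sigmoid(z) > 0.5 exactly when z > 0, thresholding the logit at zero gives the same predictions.

```python
import math


def sigmoid(z):
    """Logistic function, the scalar counterpart of tf.sigmoid."""
    return 1.0 / (1.0 + math.exp(-z))


def predict(logits, threshold=0.5):
    """Return 1.0 where sigmoid(logit) exceeds the threshold, else 0.0."""
    return [1.0 if sigmoid(z) > threshold else 0.0 for z in logits]


def accuracy(predictions, labels):
    """Fraction of predictions that equal the corresponding label."""
    matches = sum(1 for p, y in zip(predictions, labels) if p == y)
    return matches / len(labels)


# made-up values for illustration, not real model output
logits = [2.3, -0.7, 0.1, -4.0]
labels = [1.0, 0.0, 0.0, 0.0]

preds = predict(logits)
print(preds)
print(accuracy(preds, labels))
```

Running a handful of rows through a check like this is a cheap way to confirm that the frozen graph and the thresholding logic agree with what was used during training.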