有时候我们训练好了一个模型,效果还不错。那么如何保存这个模型,以便下次有新的数据时可以使用这个模型来进行预测呢?接下来我就以我的上一篇博客为基础进行模型的保存。
完整代码见Github
保存模型
具体模型训练的代码在上一篇博客讲得很清楚了,这次主要在原有的基础上进行改进。读取数据的代码没有变动。
在定义占位符的时候,加上一行命名代码,主要是为了方便我们在调用模型的时候可以准确找到模型的接口(这个后面具体会讲)
with tf.name_scope('input'): # 多加了这一句
xs = tf.placeholder(dtype='float', shape=[None, 4], name='xs')
ys = tf.placeholder(dtype='float', shape=[None, 3], name='ys')
同时,你需要模型的哪些中间结果都可以加上命名语句。比如我们想要模型输出:
with tf.name_scope('prediction'):
output = tf.nn.softmax(tf.matmul(xs, W) + b, name='output')
还需要一个准确率:
with tf.name_scope('accuracy'):
access = tf.equal(tf.argmax(output, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(access, "float"), name='accuracy')
其实保存模型的代码没有什么变化,主要是给每个tensorflow节点命名,方便我们寻找。模型训练完毕的时候就可以保存了。
print('--------------------训练结束--------------------\n\n')
print('************************性能评价************************')
print('训练集准确率:', sess.run(accuracy, {xs: x_train, ys: y_train}))
print('测试集准确率:', sess.run(accuracy, {xs: x_test, ys: y_test}))
saver = tf.train.Saver()
saver.save(sess, '../model/iris_model.ckpt')
print('模型已保存')
比之前就多了两行代码,saver = tf.train.Saver()
初始化函数,然后saver.save(sess, '你要保存的模型路径')
到这一步,模型就保存起来了。可以到你保存的模型文件夹下面看,会发现有四个文件
加载模型
模型保存好以后,接下来就可以使用我们训练好的模型去进行预测了。首先准备好我们的数据,格式要和之前训练模型的时候一样,这里我直接用测试集的数据
# 读取测试集数据
test = np.array(pd.read_csv('D:/tensorflow_exercise/data/iris_test.csv'))
x_test = test[:, 0:4]
rows = test.shape[0]
y_test = np.array(np.zeros([rows, 3]))
for r in range(rows):
label = int(test[r][4])
y_test[r][label] = 1
然后开始加载模型
sess = tf.Session() # 初始化session
# 加载模型
saver = tf.train.import_meta_graph('../model/iris_model.ckpt.meta') # 先加载meta文件,具体到文件名
saver.restore(sess, tf.train.latest_checkpoint('../model')) # 加载检查点文件checkpoint,具体到文件夹即可
graph = tf.get_default_graph() # 绘制tensorflow图
按照我的理解tf.get_default_graph()
相当于在一张纸上画了一个计算流程图,现在我们需要在图上找到输入输出口。
输入口就是我们之前定义的占位符
xs = graph.get_tensor_by_name('input/xs:0') # 获取占位符xs
ys = graph.get_tensor_by_name('input/ys:0') # 获取占位符ys
函数get_tensor_by_name
就是找到你之前定义的张量入口,还记得我们之前给张量命名吗?就是为了方便找到它,不然tensorflow会自动给它命名,你就不知道到底哪个张量对应哪个名字了。输出也是同样道理
out = graph.get_tensor_by_name(‘prediction/output:0’)
由于我们输出最后要进行准确率计算,而且我们之前定义的模型也有准确率的计算,因此我这里输出直接就是准确率,没有使用out的结果
acc = graph.get_tensor_by_name('accuracy/accuracy:0')
现在我们的输入输出口都有了,那么就可以开始喂数据进行预测了
print(sess.run(acc, feed_dict={xs:x_test, ys:y_test}))
到此,我们的整个任务就完成了~