tensorflow保存训练好的模型以及使用模型

有时候我们训练好了一个模型,效果还不错。那么如何保存这个模型,以便下次有新的数据时可以使用这个模型来进行预测呢?接下来我就以我的上一篇博客为基础进行模型的保存。
完整代码见Github

保存模型

具体模型训练的代码在上一篇博客讲得很清楚了,这次主要在原有的基础上进行改进。读取数据的代码没有变动。
在定义占位符的时候,加上一行命名代码,主要是为了方便我们在调用模型的时候可以准确找到模型的接口(这个后面具体会讲)

with tf.name_scope('input'):  # 多加了这一句
    xs = tf.placeholder(dtype='float', shape=[None, 4], name='xs')
    ys = tf.placeholder(dtype='float', shape=[None, 3], name='ys')

同时,你需要模型的哪些中间结果都可以加上命名语句。比如我们想要模型输出:

with tf.name_scope('prediction'):
    output = tf.nn.softmax(tf.matmul(xs, W) + b, name='output')

还需要一个准确率:

with tf.name_scope('accuracy'):
    access = tf.equal(tf.argmax(output, 1), tf.argmax(ys, 1))
    accuracy = tf.reduce_mean(tf.cast(access, "float"), name='accuracy')

其实保存模型的代码没有什么变化,主要是给每个tensorflow节点命名,方便我们寻找。模型训练完毕的时候就可以保存了。

print('--------------------训练结束--------------------\n\n')
print('************************性能评价************************')
print('训练集准确率:', sess.run(accuracy, {xs: x_train, ys: y_train}))
print('测试集准确率:', sess.run(accuracy, {xs: x_test, ys: y_test}))
saver = tf.train.Saver()
saver.save(sess, '../model/iris_model.ckpt')
print('模型已保存')

比之前就多了两行代码,saver = tf.train.Saver()初始化函数,然后saver.save(sess, '你要保存的模型路径')
到这一步,模型就保存起来了。可以到你保存的模型文件夹下面看,会发现有四个文件
在这里插入图片描述

加载模型

模型保存好以后,接下来就可以使用我们训练好的模型去进行预测了。首先准备好我们的数据,格式要和之前训练模型的时候一样,这里我直接用测试集的数据

# 读取测试集数据
test = np.array(pd.read_csv('D:/tensorflow_exercise/data/iris_test.csv'))
x_test = test[:, 0:4]
rows = test.shape[0]
y_test = np.array(np.zeros([rows, 3]))
for r in range(rows):
    label = int(test[r][4])
    y_test[r][label] = 1

然后开始加载模型

sess = tf.Session()  # 初始化session
# 加载模型
saver = tf.train.import_meta_graph('../model/iris_model.ckpt.meta')  # 先加载meta文件,具体到文件名
saver.restore(sess, tf.train.latest_checkpoint('../model'))  # 加载检查点文件checkpoint,具体到文件夹即可
graph = tf.get_default_graph()  # 绘制tensorflow图

按照我的理解tf.get_default_graph()相当于在一张纸上画了一个计算流程图,现在我们需要在图上找到输入输出口。
输入口就是我们之前定义的占位符

xs = graph.get_tensor_by_name('input/xs:0')  # 获取占位符xs
ys = graph.get_tensor_by_name('input/ys:0')  # 获取占位符ys

函数get_tensor_by_name就是找到你之前定义的张量入口,还记得我们之前给张量命名吗?就是为了方便找到它,不然tensorflow会自动给它命名,你就不知道到底哪个张量对应哪个名字了。输出也是同样道理
out = graph.get_tensor_by_name(‘prediction/output:0’)
由于我们输出最后要进行准确率计算,而且我们之前定义的模型也有准确率的计算,因此我这里输出直接就是准确率,没有使用out的结果

acc = graph.get_tensor_by_name('accuracy/accuracy:0')

现在我们的输入输出口都有了,那么就可以开始喂数据进行预测了

print(sess.run(acc, feed_dict={xs:x_test, ys:y_test}))

到此,我们的整个任务就完成了~

  • 12
    点赞
  • 109
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值