tensorflow保存、加载模型并预测数据

保存模型(ckpt)

仅需两行即可保存模型
saver = tf.train.Saver(tf.global_variables(), max_to_keep= 5)
#第二个参数填任意数字(用于区别各个保存的模型)
path = saver.save(sess, '../model/textCNN/model/my-model',global_step = currentStep)
wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

注意:(保存模型需要自己先建立路径文件夹)

if not os.path.exists('../model/textCNN/model'):
    os.makedirs('../model/textCNN/model')#makedirs可以建立多层文件夹

加载模型

(调用ckpt查看模型地址)

#只需替换ckpt地址即可
ckpt = tf.train.get_checkpoint_state(r'C:\Users\Jaykie\Desktop\textClassifier\model\textCNN\model' + '/')#ckpt地址
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    graph = tf.get_default_graph()
wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

或者直接输入模型地址

new_saver = tf.train.import_meta_graph('%s.meta' % (parameters["mod_trained"]))
with tf.Session() as sess:
    new_saver.restore(sess, '%s' % (parameters["mod_trained"]))
wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

配置feed_dict和输出结果

#tensor_name_list = [tensor.name for tensor in graph.as_graph_def().node]  # 得到当前图中所有变量的名称
#print(tensor_name_list)

#根据需要配置变量
_inputX = graph.get_tensor_by_name('inputX:0')
_dropoutKeepProb = graph.get_tensor_by_name('dropoutKeepProb:0')
feed_dict = {
    _inputX: trainReviews,
    _dropoutKeepProb: 1
}
#根据需要配置输出
y = graph.get_tensor_by_name('output/binaryPreds:0')
#run
predict = sess.run([y],feed_dict)

变量名根据计算图中的各个占位符的名称,用graph.get_tensor_by_name导出,记得加:0或者[0]。

如果未定义变量名称,可以通过tensor_name_list = [tensor.name for tensor in gragh.as_graph_def().node]来得到变量名称的列表

 

  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值