tensorflow里的,保存和恢复模型的3种方法

 

 重点在于,第一个文件用于 训练,保存图meta和训练好的参数data(后缀),在另一个文件中导入这个图和训练好的参数,用于预测或者接着训练。大大减少了另一个文件里的 重复

1. 第一种情况是,产生变量的代码和恢复变量的代码在同一个文件时,可以直接如下调用:

# 建模型
saver = tf.train.Saver()
 
with tf.Session() as sess:
    # 存模型,注意此处的model是文件名,不是路径
    saver.save(sess, "/tmp/model")
 
with tf.Session() as sess:
    # 恢复模型
    saver.restore(sess, "/tmp/model")

2.第二种情况,不想在另一个文件中,把产生变量的 一大堆代码重敲一遍,可以直接从保存好的 meta文件和data文件中恢复出来

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2019/9/9 20:49
# @Author  : ZZL
# @File    : 保存检查点文件,并恢复.py
import tensorflow as tf
# Saving contents and operations.
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.multiply(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
with tf.Session() as sess:
    with tf.device('/cpu:0'):
        sess.run(tf.global_variables_initializer())
        sess.run(vx.assign(tf.add(vx, vx)))
        result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
        print(result)
        print(saver.save(sess, "./model_ex1"))  # 该方法返回新创建的检查点文件的路径前缀。这个字符串可以直接传递给对“restore()”的调用。

 

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2019/9/9 20:54
# @Author  : ZZL
# @File    : 恢复文件.py
import  tensorflow as tf

saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)

先来个空图,loaded_graph,在会话中,导入之前构建好的图的文件 后缀meta,loader.restore(sess, save_model_path)

在当前的loaded_graph中,导入构建好的图和图上的变量值。

def test_model():

    test_features, test_labels = pickle.load(open('preprocess_test.p', mode='rb'))
    loaded_graph = tf.Graph()  # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320>
#     print( loaded_graph)
#     print(tf.get_default_graph())  # <tensorflow.python.framework.ops.Graph object at 0x0000017C9A0C0C50>
    with tf.Session(graph=loaded_graph) as sess:
        # 读取模型
        loader = tf.train.import_meta_graph(save_model_path + '.meta')
        print(loader)
        loader.restore(sess, save_model_path)

        print(tf.get_default_graph())  # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320>
        # 从已经读入的模型中 获取tensors 
        loaded_x = loaded_graph.get_tensor_by_name('x:0')
        loaded_y = loaded_graph.get_tensor_by_name('y:0')
        loaded_keep_prob = loaded_graph.get_tensor_by_name('keep_prob:0')
        loaded_logits = loaded_graph.get_tensor_by_name('logits:0')
        loaded_acc = loaded_graph.get_tensor_by_name('accuracy:0')
        
        # 获取每个batch的准确率,再求平均值,这样可以节约内存
        test_batch_acc_total = 0
        test_batch_count = 0
        
        for test_feature_batch, test_label_batch in helper.batch_features_labels(test_features, test_labels, batch_size):
            test_batch_acc_total += sess.run(
                loaded_acc,
                feed_dict={loaded_x: test_feature_batch, loaded_y: test_label_batch, loaded_keep_prob: 1.0})
            test_batch_count += 1

参考:

https://blog.csdn.net/qq_16234613/article/details/83013436

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值