tensoflow的模型保存与恢复

tensorflow新建的模型,保存之后会有四个文件:

checkpoint       # 用来索引最新的模型文件,模型文件包括以下三种

name.data-00000-of-00001    # 保存模型参数

name.index    # 保存模型参数

name.meta   # 用以保存模型的图结构


在训练的过程中,我们期待一直保存我们的模型,新保存的模型不会立马覆盖之前的模型,所以需要checkpoint来引导,这样恢复模型的时候可以恢复最新的模型。

所以,一个模型文件包含3个文件:其中meta后缀名的文件会保存模型的计算图结构,其他两个文件会保存参数和变量。

先新建一个模型:

import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="result")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24 which is sum of (w1+w2)*b1

# Now, save the graph
saver.save(sess, 'checkpoint\\my_test_model', global_step=1000)

我们实现的是(w1+w2)*b1的计算操作,首先用placeholder来确定输入接口,意思是w1和w2是我的输入入口,这个输入入口直接和feed_dict对应。描述完计算图之后,就可以新建一个会话,会话就是对接工作,是执行计算的入口。

最后,把会话存起来: saver.save(sess, 'checkpoint\\my_test_model', global_step=1000)

sess表示会话名,'checkpoint\\my_test_model'表示保存路径,global_step会在保存路径后面加上-1000,做一个命名区分。

保存后的模型如下:

my_test_model-1000.index

my_test_model-1000.meta

my_test_model-1000.data-00000-of-00001

checkpoint


恢复模型

import tensorflow as tf

with tf.Session() as sess:
    # First let's load meta graph and restore weights
    saver = tf.train.import_meta_graph('checkpoint\\my_test_model-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('checkpoint/'))

    # Now, let's access and create placeholders variables and
    # create feed-dict to feed new data

    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict = {w1: 13.0, w2: 13}

    # Now, access the op that you want to run.
    op_to_restore = graph.get_tensor_by_name("result:0")

    print(sess.run(op_to_restore, feed_dict))
    # This will print 60 which is calculated
    # using new values of w1 and w2 and saved value of b1.

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

木盏

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值