tensorflow保存和恢复模型

以下代码在Python3.6和TensorFlow>=1.10运行通过。

1.tensorflow模型有两个文件组成:

(1)meta graph:
这是一个协议缓冲区, 它保存了完整的tensorflow图形,即所有变量、操作、集合等。该文件以.meta作为扩展名。
(2)checkpoint file:
这是一个二进制文件,它包含了所有的权重、偏置、梯度和其他所有变量的值。这个文件有一个扩展名.ckpt。Tensorflow从0.11版本中改变了这一点。现在,我们有两个文件,而不是单个.ckpt文件:
mymodel.data-00000-of-00001
mymodel.index
.data文件保存的是变量值,.index文件保存的是.data文件中数据和 .meta文件中结构图之间的对应关系。与此同时,Tensorflow也有一个名为checkpoint的文件,它只保存的最新保存的checkpoint文件的记录。

2. 保存TensorFlow模型

在Tensorflow中,我们希望保存所有参数的图和值,我们将创建一个tf.train.Saver()类的实例。

saver = tf.train.Saver()

Tensorflow变量仅在会话中存在。因此,您必须在一个会话中保存模型,调用您刚刚创建的save方法。

saver.save(sess, “./my-test-model")

看一个例子:

import tensorflow as tf

w1 = tf.placeholder(tf.float32, name="w1")
w2 = tf.placeholder(tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

print(sess.run(w4, feed_dict=feed_dict))

saver.save(sess, "./my_test_model", global_step=1000)

如果我们在tf.train.Saver()中没有指定任何东西,它将保存所有的变量。如果,我们不想保存所有的变量,而只是一些变量。我们可以指定要保存的变量/集合。在创建tf.train。保护程序实例,我们将它传递给我们想要保存的变量的列表或字典。让我们来看一个例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1, w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, "./my_test_model",global_step=1000)

这可以用于保存一些特定的参数

3.导入训练好的模型

如果你想用别人预先训练好的模型来进行微调,你需要做以下两件事:
(1)创建网络
我们已经在.meta文件中保存了这个网络,我们可以使用tf.train.import()函数来重新创建这个网络:

saver = tf.train.import_meta_graph('my_test_model-1000.meta')

记住,import_meta_graph将在.meta文件中定义的网络附加到当前图。因此,这将为你创建图形/网络,但是我们仍然需要加载我们在这张图上训练过的参数的值。
(2)载入参数
我们可以通过调用这个保护程序的实例来恢复网络的参数,它是tf.train.Saver()类的一个实例。

with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))

因此,现在你已经了解了如何为Tensorflow模型保存和导入工作。

4.实例-使用保存好的模型

让我们开发一个实用的例子来恢复任何预先训练的模型,并使用它进行预测、微调或进一步训练。你将定义一个图,该图是feed examples(训练数据)和一些超参数(如学习速率、迭代次数等),它是一个标准的过程,我们可以使用占位符来存放所有的训练数据和超参数。接下来,让我们使用占位符构建一个小网络并保存它。注意,当网络被保存时,占位符的值不会被保存。

如果我们只是想用不同的数据运行相同的网络,您可以简单地通过feed_dict将新数据传递给网络。

import tensorflow as tf

w1 = tf.placeholder(tf.float32, name="w1")
w2 = tf.placeholder(tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

print(sess.run(w4, feed_dict=feed_dict))

saver.save(sess, "./my_test_model", global_step=1000)

sess = tf.Session()
saver = tf.train.import_meta_graph("my_test_model-1000.meta")
saver.restore(sess, tf.train.latest_checkpoint('./'))

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict=feed_dict))

如果你希望通过添加更多的层数并对其进行训练,从而向图中添加更多的操作,可以这样做

import tensorflow as tf

w1 = tf.placeholder(tf.float32, name="w1")
w2 = tf.placeholder(tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

print(sess.run(w4, feed_dict=feed_dict))

saver.save(sess, "./my_test_model", global_step=1000)

sess = tf.Session()
saver = tf.train.import_meta_graph("my_test_model-1000.meta")
saver.restore(sess, tf.train.latest_checkpoint('./'))

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
add_on_op = tf.multiply(op_to_restore, 2)
print(sess.run(add_on_op, feed_dict=feed_dict))

关于上面代码中按张量名获取张量中的("w1:0"),如果改成("w1"),则会报错:ValueError: The name 'w1' refers to an Operation, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".

参考:
https://blog.csdn.net/tan_handsome/article/details/79303269

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值