TensorFlow保存或加载训练的模型

什么是Tensorflow的模型

模型部分主要参考了这篇文章这篇博客;另外,官方文档也给出了很多指导。
Tensorflow的模型主要包括神经网络的架构设计(或者称为计算图的设计)和已经训练好的网络参数。因此,Tensorflow模型包括的主要文件:

  1. “.meta”:包含了计算图的结构
  2. “.data”:包含了变量的值
  3. “.index”:确认checkpoint
  4. “checkpiont”:一个protocol buffer,包含了最近的一些checkpoints

存储一个Tensorflow的模型

当我们训练的神经网络模型的损失函数或者精度收敛时,我们需要把参数或者网络结构存储起来。如果我们想要存储整个网络结构和该网络的所有参数,我们需要创建一个tf.train.Saver()的实例。Tensorflow变量的作用域仅在Session内部。因此,我们必须在一个Session的内部存储有关的数据。

saver.save(sess,'my_test_model')

sess是我们创建的一个Session实例,my_test_model是我们给模型的命名。
具体的实例:

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './my_test_model')
sess.close()

执行上述语句,我们会同级目录下看到新增的文件:

my_test_model.data-00000-of-00001
my_test_model.index
my_test_model.meta

如果网络架构更改了,Tensorflow会重写上述的文件。

如果我们想要每1000步保存一次,那么需要更改语句:

saver.save(sess, 'my_test_model', global_step=1000)

那么当训练时,我们会每1000次迭代存储一次模型。.meta会在第一次到达1000次迭代时创建,之后的每千步,就不需要在重新创建.meta文件了。只要图的架构 不更改,就不需要重新创建.meta文件。 如果不写步数,默认每次迭代保存一次。

如果我们要仅仅保留最近4次创建的模型,并且每两个小时存储一次模型,可以进行下面的操作:

# saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

如果我们在tf.train.Saver()中不指定任何参数,那么Tensorflow会默认保存所有的变量。假设我们只想保留部分变量或者collection,那么需要显式地表明需要保留的对象。当创建tf.train.Saver()对象时,使用一个包含有关变量的list或者字典声明。比如:

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1, w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './my_test_model')
sess.close()

导入一个训练好的模型

如果我们要导入一个训练好的模型,需要做以下两步:

创建一个网络

使用函数:

saver = tf.train.import_meta_graph('my_test_model-1000.meta')

把存储在my_test_model-1000.meta加载到saver当中。这个操作知识会把在.meta文件中定义的网络追加到当前网络的后面,我们仍然需要加载原来网络的参数数值。

加载参数

操作如下:

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    new_saver.restore(sess, tf.train.lasters_checkpoint('./'))

在这之后,w1w2的数据就会被重新加载进来。

对导入的模型进行的操作

现在,学着加载模型,把模型用于预测、训练甚至更改模型的架构。现在构造一个简单的网络模型,保存并重新导入。注意一点:tf.placeholder的数据不会被保存 !!!!
先定义训练文件:

import tensorflow as tf

# 定义用于恢复变量的例子
w1 = tf.placeholder(dtype=tf.float32, name="w1")
w2 = tf.placeholder(dtype=tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# 定义用于恢复操作的例子   w4=w3*b1,w3=(w1+w2)*b1
w3 = tf.add(w1, w2, name="part_op")
w4 = tf.multiply(w3, b1, name="op_to_restore")

sess = tf.Session()
sess.run(tf.global_variables_initializer())  # 时刻记着,要初始化

saver = tf.train.Saver()

print(sess.run(w4, feed_dict))  # 24.0

saver.save(sess, './my_test_model', global_step=1000)

sess.close()

定义加载文件:

import tensorflow as tf

sess = tf.Session()

saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# w4=w3*b1,w3=(w1+w2)*b1
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")  # 60.0

print(sess.run(op_to_restore, feed_dict))

sess.close()

当导入模型的时候,不但需要恢复计算图和相关的参数,而且需要重新对tf.placeholder喂数据。通过graph.get_tensor_by_name获取保存的操作和占位符。如果我们想要使用网络计算,仅需要给不同的占位符添加不同的数据即可。

如果我们想要对原来的网络添加更多的层数并接着训练它,可以按照下面的步骤处理:

import tensorflow as tf

sess = tf.Session()
# 恢复计算图
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
# 获取占位符
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
# 恢复操作
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
# 增加新的操作
add_on_op = tf.multiply(op_to_restore, 2.0)
# 别忘了喂数据
print(sess.run(add_on_op, feed_dict))

sess.close()

由此可以看出,只需要把原来的操作加载完毕后,当成一个输出数据接入新的网络即可。

也可以把原来网络的一部分加载 到新的网络中,比如下面的操作:
先更改之前的一行代码

w3 = tf.add(w1, w2, name="part_op")

加载操作:

import tensorflow as tf

sess = tf.Session()

saver = tf.train.import_meta_graph("my_test_model-1000.meta")
saver.restore(sess, tf.train.latest_checkpoint('./'))

graph = tf.get_default_graph()

w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 14.0}

w3 = graph.get_tensor_by_name("part_op:0")

op = tf.multiply(w3, 4)
print(sess.run(op, feed_dict))  # 108.0
sess.close()

使用SavedModel的格式

SavedMode类把Saver类进行了一个更高层的封装,开发效率可能会更高,但是暂时没有前一种方法常用。Saver类更看重对变量的封装, 而SavedModel更看重压缩封装保存所有有用的信息。

保存操作:

import tensorflow as tf

tf.reset_default_graph()

w1 = tf.Variable(1.0, name="w1")
w2 = tf.Variable(2.0, name="w2")
w3 = tf.multiply(w1, w2, name="w3")

builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(w3)
    builder.add_meta_graph_and_variables(sess,
                                         [tf.saved_model.tag_constants.TRAINING],
                                         signature_def_map=None,
                                         assets_collection=None)
builder.save()

读取操作:

import tensorflow as tf

with tf.Session() as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING],
                               './SavedModel')

    w1 = sess.run('w1:0')
    w2 = sess.run('w2:0')
    w3 = sess.run('w3:0')

    print(w1, w2, w3)
  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值