tensorflow 中checkpoint详解

本文详细介绍了TensorFlow中checkpoint的使用,包括文件结构、meta文件、data文件和index文件的作用。讨论了如何在训练过程中定期保存模型,以及如何在后续训练或推理时导入模型,包括恢复网络结构和变量参数。同时,文章提到了如何从ckpt文件恢复并继续训练,以及如何在恢复后修改模型结构。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. checkpoint(*.ckpt)

Checkpoint是用于描述在每次训练后保存模型参数(权重)的惯例或术语。这就像在游戏中保存关卡时你可以随时通过加载保存文件回复游戏。检查点checkpoint中存储着模型model所使用的的所有的 tf.Variable 对象,它不包含任何关于模型的计算信息,因此只有在源代码可用,也就是我们可以恢复原模型结构的时候,checkpoint才有用,否则不知道模型的结构,仅仅只知道一些Variable是没有意义的。

1.1 文件结构介绍

—checkpoint

—model.ckpt-240000.data-00000-of-00001

—model.ckpt-240000.index

—model.ckpt-240000.meta

1.2 meta文件(保存了tensorflow完整的网络结构)

.meta文件:一个协议缓冲,保存tensorflow中完整的graph、variables、operation、collection;这是我们恢复模型结构的参照;
meta文件保存的是图结构,通俗地讲就是神经网络的网络结构。当然在使用低层PAI编写神经网络的时候,本质上是一系列运算以及张量构造的一个较为复杂的graph,这个和高层API中的层的概念还是有区别的,但是可以这么去理解,整个graph的结构就是网络结构。一般而言网络结构是不会发生改变,所以可以只保存一个就行了。我们可以使用下面的代码只在第一次保存meta文件。

saver.save(sess, 'my_model.ckpt', global_step=step, write_meta_graph=False)

在后面恢复整个graph的结构的时候,并且还可以使用

tf.train.import_meta_graph(‘xxxxxx.meta’)

1.3 data文件

model.ckpt-240000.data-00000-of-00001:数据文件,保存的是网络的权值,偏置,操作等

1.4 index文件

model.ckpt-240000.index :是一个不可变得字符串字典,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。用于描述variable中key和value的关系。 每个BundleEntryProto描述张量的元数据,所谓的元数据就是描述这个Variable 的一些信息的数据。 “数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。

2. 导出与导入

2.1 导出ckpt

在结束训练后,把所有的变量和网络结构保存下来使用tf.train.Saver,eg:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

如果迭代1000次以后保存模型,可以把当前迭代次数存入

saver.save(sess,'model',gloal_step=1000)

训练的时候,假设1000次就保存一次模型,但上述三个文件中仅改变网络的参数,网络结构不发生变化,无需重复保存.meta文件,我们可以设置只保存网络结构一次

saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)

如果只想保留最新的四个模型,并且两个小时保存一次,可以使用max_to_keep和keep_checkpoint_every_n_hours:

saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

如果没有在tf.train.Saver()指定任何参数,这样表示默认保存所有变量。如果我们不希望保存所有变量,而只是其中的一部分,此时我们可以指点要保存的变量或者集合:我们只需在创建tf.train.Saver的时候把一个列表或者要保存变量的字典作为参数传进去。

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver.save(sess, 'my_test_model',global_step=1000)

2.1 导入ckpt

  1. 从meta文件导入原始网络结构图
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  1. 加载变量
    使用restore()方法恢复模型的变量参数
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('model')
new_saver.restore(sess,tf.train.latest_checkpoint('./'))

2.2 从ckpt文件导入后继续训练

恢复任何预训练的模型,并用它进行inference,fine-tuning或者进一步训练。在tensorflow中,如果有占位符,那么需要将数据传入占位符,但在保存model时,占位符数据不被保存

import tensorflow as tf
 
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
 
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
 
#Create a saver object which will save all the variables
saver = tf.train.Saver()
 
#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 
 
#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

我们需要恢复时,不仅需要恢复网络结构和相关变量参数,而且还需要准备新的feed_dic传入占位符。通过graph,get_tensor_by_name() 方法可以恢复所保存的占位符和opertor。比如下面的W1是一个占位符,op_to_restore是一个算子。

#How to access saved variable/Tensor/placeholders 
w1 = graph.get_tensor_by_name("w1:0")
 
## How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

完整的example:
import tensorflow as tf;
import os;

model_saving_path = "./checkpoint"
model_name = 'saving_restoring';


def save():
    w1 = tf.placeholder(dtype=tf.float32, name='w1');
    w2 = tf.placeholder(dtype=tf.float32, name='w2');
    b1 = tf.Variable(2.0, name='bias');
    feed_dict = {w1:4, w2:8};

    w3 = tf.add(w1, w2)
    w4 = tf.multiply(w3, b1, name='op_to_restore');
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver();
        print(sess.run(w4, feed_dict));
        saver.save(sess, os.path.join(model_saving_path, model_name), global_step=1000);


def restore0():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(
            os.path.join(model_saving_path, model_name+'-1000.meta'))
        saver.restore(sess, tf.train.latest_checkpoint(model_saving_path))

        graph = tf.get_default_graph();
        w1 = graph.get_tensor_by_name('w1:0');
        w2 = graph.get_tensor_by_name('w2:0');
        feed_dict = {w1:13.0, w2:17.0};

        op_to_restore = graph.get_tensor_by_name('op_to_restore:0');
        print(sess.run(op_to_restore, feed_dict))


def restore():
"""不能以这样的方式恢复占位符,会报错:
InvalidArgumentError (see above for traceback):
 You must feed a value for placeholder tensor 'w1_1' with dtype float
因为对于一个占位符而言,它所包含的不仅仅是占位符变量的定义部分,
还包含数据,而tensorflow不保存占位符的数据部分。
应通过graph.get_tensor_by_name的方式获取,然后在feed数据进去"""

    w1 = tf.placeholder(dtype=tf.float32, name='w1');
    w2 = tf.placeholder(dtype=tf.float32, name='w2');
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(
            os.path.join(model_saving_path, model_name+'-1000.meta'))
        saver.restore(sess, tf.train.latest_checkpoint(model_saving_path))

        graph = tf.get_default_graph();
        # w1 = graph.get_tensor_by_name('w1:0');
        # w2 = graph.get_tensor_by_name('w2:0');
        feed_dict = {w1:13.0, w2:17.0};

        op_to_restore = graph.get_tensor_by_name('op_to_restore:0');
        print(sess.run(op_to_restore, feed_dict))

save()
restore0();

2.3 从ckpt文件恢复训练,并修改模型结构

在原来的神经网络加更多的层,继续训练

def restore2():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(
            os.path.join(model_saving_path, model_name+'-1000.meta'))
        saver.restore(sess, tf.train.latest_checkpoint(model_saving_path))

        graph = tf.get_default_graph();
        w1 = graph.get_tensor_by_name('w1:0');
        w2 = graph.get_tensor_by_name('w2:0');
        feed_dict = {w1:13.0, w2:17.0};

        op_to_restore = graph.get_tensor_by_name('op_to_restore:0');
        # Add more to the current graph
        add_on_op = tf.multiply(op_to_restore, 2)
        print(sess.run(add_on_op, feed_dict))
        # This will print 120

如果我只想恢复神经网络的一部分参数或者一部分算子,然后利用这一部分参数或者算子构建新的神经网络模型:我们可以使用graph.get_tensor_by_name() 方法。下面是个例子,在这里我们使用.meta加载一个预训练好的VGG网络,并做一些修改:

saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 

#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')

#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()

new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值