TensorFlow模型的保存与加载(一)——checkpoint模式【源码】

如果这篇文章对您有帮助,欢迎点赞支持!

目录

前言

1、TF模型保存方法

2、checkpoint模式

3、适合保存模型的时机

一、保存模型

1、创建Saver对象

2、检查保存路径

3、生成模型文件

二、加载模型

1、加载模型结构

2、加载模型参数

3、获取Tensor变量

三、代码封装

1、保存网络模型

2、加载网络模型


前言

1、TF模型保存方法

网络模型的保存和重载操作是学习和训练AI模型的必备技能之一,也是进一步学习迁移学习知识的基础。

Tensorflow的模型保存加载有不同格式,使用方法也不一样。

目前来看,Tensorflow的模型保存方式按照生成的主要文件的格式基本可以分为三种:(1)checkpoint模式;(2) pb模式;(3)saved_model模式

2、checkpoint模式

checkpoint模式是最常见的一种保存方式,其特点是将网络结构变量数据分开保存,该方式保存的模型基本只能使用TensorFlow来重载,保存好的模型的文件结构如下:

|--checkpoint_dir
|    |--checkpoint
|    |--test-model-550.meta
|    |--test-model-550.data-00000-of-00001
|    |--test-model-550.index

checkpoint_dir就是保存时我们指定的路径,该路径下会生成4个文件(文件中的550是指该模型训练了550个step)

其中.meta文件(本质是pb格式文件)用来保存模型的网络结构.data.index文件用来保存模型中的各种变量

checkpoint文件里面记录了最新的checkpoint文件以及其它checkpoint文件列表,在inference时可以通过修改这个文件,指定重载使用哪个model。

3、适合保存模型的时机

尽可能多地保存模型能帮助我们不错过效果最好的模型,但是实际操作也要考虑内存大小运行效率

工程师的通用做法是每训练多少步后就在验证集上计算一次准确率,如果本次结果比上次好则保存新的模型。

最常见的做法是直接每训练多少步就直接保存一次模型,实际上TensorFlow的API也提供了在模型文件名提供添加训练step后缀的方式。


一、保存模型

1、创建Saver对象

TensorFlow提供了tf.train.Saver()对象来保存和还原一个机器学习模型。其常见用法如下:

saver = tf.train.Saver(max_to_keep=3)

在创建这个Saver对象的时候,经常会用到max_to_keep 参数来设置保存模型的个数,默认max_to_keep=5,即保存最新的5个模型。

max_to_keep是指本次训练在checkpoint_dir这个路径下最多保存多少个模型文件,新模型会覆盖旧模型以节省空间。

如果想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,但是这样做除了多占用硬盘,并没有实际多大的用处,不推荐设置。

2、检查保存路径

保存模型必然需要先检查下保存路径是否存在,一段可以检查Python路径的代码如下:

# 检查模型保存路径
ckpt_file_path = "./model/"
path = os.path.abspath(ckpt_file_path)
if os.path.exists(path) is False:
    os.makedirs(path)

注意:保存路径中及其所在的绝对路径中最好不要有中文路径,容易出错。

3、生成模型文件

TensorFlow提供了如下API来保存一个计算图模型:

sess = tf.Session() # 创建tf会话
tf.train.Saver().save(sess, ckpt_file_path,global_step=step,write_meta_graph=True)

其中第1个参数是创建的tf会话,tf会话中加载有计算图graph对象,可以通过此获取当前要保存的计算图模型;

第2个参数是模型保存的文件路径及其文件名;

第3个参数是训练的次数,该API会自动将训练的次数作为后缀加入到模型名字中;

第4个参数是是否生成.meta文件,该文件包含了模型的网络结构。

该API正常执行后程序会生成并保存如下4类文件:

image.png

注意,其中checkpoint是文本文件,在训练过程中多次保存模型时只要保存路径不变,该文件将只有一份。其内容大致如下:

image.png


二、加载模型

1、加载模型结构

加载模型结构有两种方式:1. 加载.meta文件中的结构, 2. 手动重新写一遍原样结构。

第2种方式适合在有源码的情况下进行生成,第1种方式则是在没有源码的情况下生成的,其主要是加载.meta文件

saver = tf.train.import_meta_graph('./model/dog-cat.ckpt-9975.meta')# 加载模型结构

2、加载模型参数

加载模型参数就是加载模型的变量,主要方法是用saver.restore()方法恢复变量:

# 加载网络模型
sess = tf.Session() # 创建tf会话
# saver.restore(sess, './model/dog-cat.ckpt-9975')# 指定特定模型
saver.restore(sess, tf.train.latest_checkpoint('./model/'))# 使用最新模型
sess.run(tf.global_variables_initializer()) # 重新初始化模型变量

该API中第1个参数是创建的tf会话,加载的模型将会保存到该会话中;

第2个参数是模型保存的路径,不需要提供模型的名字

3、获取Tensor变量

在上述获得模型结构及其参数后,就可以使用该模型了,现在模型已经被加载到sess中,我们可以获取此时对应的计算图,通过图获取其中的Tensor变量,注意这里获取的Tensor变量名必须在保存时有定义。

graph = tf.get_default_graph() # 获取计算图
self.x = graph.get_tensor_by_name("x:0")
self.y_true = graph.get_tensor_by_name("y_true:0")
self.y_pred = graph.get_tensor_by_name("y_pred:0")

其中x是节点名称,x:0是表述节点的输出的第一个张量


三、代码封装

在上述内容我们只是介绍了方法步骤,而且路径的输入上也存在一些技巧,所以我们实际使用时要将上述步骤封装起来,下面我们将上述方法封装起来并给出测试代码和运行效果

1、保存网络模型

为了让读者将精力集中到模型的保存和加载上,这里我们使用一个非常简单的计算图结构x*y+b

# 定义计算图结构 x*y+b
x = tf.placeholder(tf.int32, name='x_input')
y = tf.placeholder(tf.int32, name='y_input')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
output = tf.add(xy, b, name='output')

运行上述计算图,相关测试代码如下:

# 输出计算图结果
sess = tf.Session()
sess.run(tf.global_variables_initializer())
y_pred = sess.run(output, {x: 10, y: 3})
print(y_pred) # 输出31

最后输出结果是10*3+1=31,和我们预想的一致。

接下来我们使用如下方法保存该计算图模型:

# 使用ckpt保存计算图
save_ckpt_model(sess,'./model/',1)

save_ckpt_model()是我们封装的方法,其相关代码如下:

def save_ckpt_model(sess, save_path,train_step=None):
    saver = tf.train.Saver(max_to_keep=5)
    # 检查路径是否存在
    path = os.path.abspath(save_path)  # 获取绝对路径
    if os.path.exists(path) is False:
        os.makedirs(path)
        print("成功创建模型保存新路径:{}".format(path))
    saver.save(sess, save_path + 'model.ckpt', global_step=train_step, write_meta_graph=True)  # 保存为ckpt模型
    print("成功使用CKPT模式保存模型到路径:{}".format(path))

成功运行后,会生成如下文件:

image.png

注意,运行代码中,代码编辑器可能不会马上更新文件目录,建议打开目录查看下是否生成。

2、加载网络模型

现在我们要将上述保存的计算图结构x*y+b从文件中重载出来,下面直接给出加载代码:

# 加载计算图结构 x*y+b
sess = tf.Session()
load_ckpt_model(sess, './model/')
# 获取计算图节点
graph = tf.get_default_graph()  # 获取计算图
x = graph.get_tensor_by_name("x_input:0")
y = graph.get_tensor_by_name("y_input:0")
output= graph.get_tensor_by_name("output:0")
# 运行计算图计算结果
y_pred = sess.run(output, {x: 10, y: 3})
print(y_pred) # 输出 31

其中加载方法load_ckpt_model()代码如下:

def load_ckpt_model(sess, save_path):
    checkpoint = tf.train.get_checkpoint_state(save_path)  # 从checkpoint文件中读取checkpoint对象
    input_checkpoint = checkpoint.model_checkpoint_path
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)  # 加载模型结构
    saver.restore(sess, input_checkpoint)  # 使用最新模型
    sess.run(tf.global_variables_initializer())# 初始化所有变量

最后,如果这篇文章对您有帮助,欢迎点赞支持,感谢您的善意!

  • 22
    点赞
  • 54
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

魔法攻城狮MRL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值