TensorFlow 模型保存和加载

在TensorFlow中,我们希望模型训练完成后将模型保存下来,等到测试阶段再从文件中加载训练好的模型进行测试。这就要用到tf.train.Saver()这个类。

保存一个TensorFlow的模型

下面举一个简单的例子利用 tf.train.Saver 来保存模型。

import tensorflow as tf
import os
import numpy as np
 
a = tf.Variable(tf.random_normal(shape=[2,2]), name='a')
b = tf.Variable(tf.random_normal(shape=[2,2]), name='b')
 
model_path = './model/'  # 模型保存路径
model_name = 'test_model'  # 模型保存文件名称
model_save_path = os.path.join(model_path, model_name) # 路径+文件名
 
saver = tf.train.Saver()  # 创建一个Saver对象
# saver = tf.train.Saver(max_to_keep=1)  # max_to_keep :设置保存模型的个数,默认为5
 
with tf.Session() as save_sess:
    save_sess.run(tf.global_variables_initializer())
    c = save_sess.run(tf.add(a, b))
    saver.save(sess=save_sess, save_path=model_save_path, global_step=step) # 保存模型
    print("save model success") 

接下来我们详细看一下tf.train.Saver.save 函数

save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True, strip_default_attrs=False)
'''
Args:
  sess: A Session to use to save the variables.
  save_path: String.  Prefix of filenames created for the checkpoint.
  global_step: If provided the global step number is appended to `save_path` to create the checkpoint filenames. The optional argument can be a `Tensor`, a `Tensor` name or an integer.
'''
  • 第一个参数 sess=sess,即会话名称
  • 第二个参数 save_path=model_save_path,设定权重参数保存到的路径和文件名称
  • 第三个参数 global_step=step,将训练的次数作为后缀加入到模型名字中

模型保存结果

运行上述代码后,我们可以看到以下几个文件
在这里插入图片描述

  • .meta文件:meta graph保存了tensorflow的graph。包括variables,operations,collections等等。
  • .data-00000-of-00001和.index文件:二进制文件,保存了所有weights,biases,gradient and all the other variables的值。
  • checkpoint的文件,保存最新检查点文件的记录。

导入训练好的模型

在训练好的模型中,.meta文件中已经保存了整个graph,我们无需重建,只要导入上面保存好的文件即可。

with tf.Session() as test_sess:
    ckpt = tf.train.get_checkpoint_state(model_save_path)
    if ckpt and ckpt.model_checkpoint_path: # 判断checkpoint是否存在
        saver.restore(test_sess,ckpt.model_checkpoint_path) # 加载模型中各种变量的值
        # saver.restore(test_sess, './model/test_model') # 这里不用加文件的后缀
    print("load model success")

接下来我们详细看一下tf.train.Saver.restore 函数

restore(self, sess, save_path)
'''
Restores previously saved variables.
Args:
  sess: A `Session` to use to restore the parameters. None in eager mode.
  save_path: Path where parameters were previously saved.
'''

附一个包含有 训练模型 -> 保存模型 -> 加载模型 -> 模型预测 过程的完整代码:(代码来源 https://blog.csdn.net/u012856866/article/details/104699327

第一个文件,训练模型并保存模型:

#定义模型
X = tf.placeholder(tf.float32,shape = [None,x_dim],name = 'X')
Y = tf.placeholder(tf.float32,shape = [None,1], name = 'Y')
W = tf.Variable(tf.random_normal([x_dim,1]),name='weight')
b = tf.Variable(tf.random_normal([1]),name='bias')

hypothesis = tf.sigmoid(tf.matmul(X,W)+b)
cost = -tf.reduce_mean(Y*tf.log(hypothesis) + (1-Y)*tf.log(1-hypothesis))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)

# tf.add_to_collection:保存hypothesis和cost,以便重新导入模型时可以使用
tf.add_to_collection(name='hypothesis', value=hypothesis) 
tf.add_to_collection(name='cost', value=cost)

saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(50):
    avg_cost, _ = sess.run([cost,train],feed_dict = {X:x_data,Y:y_data})

mysaver.save(sess, './model/model_LR_test') #保存模型

第二个文件,加载模型,并利用训练好的模型预测:

sess = tf.Session()
#本来我们需要重新像上一个文件那样重新构建整个graph,但是利用下面这个语句就可以加载整个graph了,方便
new_saver = tf.train.import_meta_graph('../model/model_LR_test.meta')
new_saver.restore(sess,'../model/model_LR_test')#加载模型中各种变量的值,注意这里不用文件的后缀

#对应第一个文件的add_to_collection()函数
hyp = tf.get_collection('hypothesis')[0] #返回值是一个list,我们要的是第一个,这也说明可以有多个变量的名字一样。

graph = tf.get_default_graph() 
X = graph.get_operation_by_name('X').outputs[0]#为了将placeholder加载出来

pred = sess.run(hyp,feed_dict = {X:x_valid})
print('auc:',auc(y_valid,pred))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值