Tensorflow—模型持久化的相关问题
TensorFlow使用tf.train.Saver类实现模型的保存和提取。
– 通过Saver对象的restore方法可以加载模型,并通过保存好的模型变量相关值重新加载完全加载进来。
– 如果不希望重复定义计算图上的运算,可以直接加载已经持久化的图,通过tf.train.import_meta_graph方法直接加载
保存模型
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# 加几个随机数种子,让多次运行的时候,随机数列一致
np.random.seed(28)
tf.set_random_seed(28)
if __name__ == '__main__':
with tf.Graph().as_default():
# 一、执行图的构建
with tf.variable_scope('network'):
# a. 定义占位符
input_x = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='x')
input_y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='y')
# b. 定义模型参数
w = tf.get_variable(name='w2', shape=[2, 1], dtype=tf.float32,
initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0))
b = tf.get_variable(name='b2', shape=[1], dtype=tf.float32,
initializer=tf.zeros_initializer())
# c. 模型预测的构建(获取预测值)
y_ = tf.matmul(input_x, w) + b
with tf.name_scope('loss'):
# d. 损失函数构建(平方和损失函数)
loss = tf.reduce_mean(tf.square(input_y - y_))
tf.summary.scalar('loss', loss)
print(loss)
with tf.name_scope('train'):
# e. 定义优化器(优化器的意思:求解让损失函数最小的模型参数<变量>的方式)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
# f. 定义一个训练操作对象
train_op = optimizer.minimize(loss=loss)
# 二、执行图的训练运行
with tf.Session() as sess:
# a. 创建一个持久化对象
saver = tf.train.Saver()
# a. 变量的初始化操作
sess.run(tf.global_variables_initializer())
# 获取一个日志输出对象
writer = tf.summary.FileWriter(logdir='./models/14/graph', graph=sess.graph)
# 获取所有的summary输出操作
summary = tf.summary.merge_all()
# b. 训练数据的产生/获取(基于numpy随机产生<可以先考虑一个固定的数据集>)
N = 100
dim = 2
x = np.random.uniform(low=-10, high=10, size=(N, dim))
y = np.dot(x, [[5], [-0.5]]) + 12 + np.random.normal(0, 5.0, (N, 1))
x.shape = -1, dim
y.shape = -1, 1
print((np.shape(x), np.shape(y)))
# c. 模型训练
for step in range(100):
# 1. 触发模型训练操作
_, loss_, summary_ = sess.run([train_op, loss, summary], feed_dict={
input_x: x,
input_y: y
})
print("第{}次训练后模型的损失函数为:{}".format(step, loss_))
writer.add_summary(summary_, global_step=step)
# 触发模型持久化
save_path = './models/14/model/model.ckpt'
dirpath = os.path.dirname(save_path)
if not os.path.exists(dirpath):
os.makedirs(dirpath)
saver.save(sess, save_path=save_path) #
# 关闭输出流
writer.close()
上面的代码是用tf实现线性回归的简单代码,我们直接在上面进行加模型的持久化操作~
模型持久化的步骤为:
#第一步:创建一个持久化对象
saver = tf.train.Saver()
#第二步:触发模型持久化
save_path = './models/14/model/model.ckpt'
dirpath = os.path.dirname(save_path)
if not os.path.exists(dirpath):
os.makedirs(dirpath)
saver.save(sess, save_path=save_path)
其中,tf.train.Saver()参数为:
def __init__(self,
var_list=None, 给定具体持久化那些模型参数,默认是持久化所有的变量<参与模型训练的>
reshape=False,
sharded=False,
max_to_keep=5, 指定最多同时保留最近多少份模型
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
比如只想保存其中部分的参数,例如只保存参数w,那么就是:
saver = tf.train.Saver([w])
在tensorflow中,from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file这个api可以查看ckpt文件中保存的参数到底是个啥~
#还是用上述的线性回归代码:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file("D:/chrome_down/05_tf/tensorflow14/tf/models/14/model/model.ckpt",None,True)
'''
运行结果为:
tensor_name: network/b2
[10.069975]
tensor_name: network/w2
[[ 5.0179863 ]
[-0.61515224]]
~~~~~通过上面的结果,得知ckpt模型保存了w和b两个参数~~~~~
如果saver = tf.train.Saver([w])的话,运行结果为:
tensor_name: network/w2
[[ 5.0179863]
[-0.6151727]]
~~~~~通过上面的结果,得知ckpt模型只保存了w两个参数~~~~~
'''
在运行代码结束之后,会在对应的路径生成以下文件:
下面,讲解以下这四个文件的作用:
checkpoint:checkpoint文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。(如果不指定保存哪些参数,会默认保存所有参数)
model.ckpt.data:保存模型中的参数值。
model.ckpt.index:保存模型的参数名。
model.ckpt.meta:保存图结构。
值得注意的是:当saver.save参数global_step设置为每一步的时候,最后会生成五个最新的步骤生成的ckpt文件~
加载模型
当有保存好的模型,我们可以选择加载它们~
save_path = './models/14/model/model.ckpt'
saver.restore(sess, save_path)
'''如果在模型回复的过程中,参数名字发生改变,加下面一句代码:
saver = tf.train.Saver({"network/w2": w, "network/b2": b})
其中,w和b是名称修改后的tensor对象,w2和b2是模型保存时候的名称
'''
或者:
# 获取持久化的信息对象
ckpt = tf.train.get_checkpoint_state('./models/18/model')
print(ckpt.model_checkpoint_path)
if ckpt and ckpt.model_checkpoint_path:
print("进行模型恢复操作...")
# 恢复模型
saver.restore(sess, ckpt.model_checkpoint_path)
# 恢复checkpoint的管理信息
saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
else:
# 如果文件不存在,进行初始化
print("进行模型参数初始化操作...")
sess.run(tf.global_variables_initializer())
其中,tf.train.get_checkpoint_state的作用是:通过checkpoint文件找到模型文件名~~
如果,代码中没有图的构建,直接通过图的恢复:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# 加几个随机数种子,让多次运行的时候,随机数列一致
np.random.seed(28)
tf.set_random_seed(28)
if __name__ == '__main__':
with tf.Graph().as_default():
with tf.Session() as sess:
# 恢复图中的执行信息
ckpt = tf.train.get_checkpoint_state('./models/14/model')
if ckpt is None or ckpt.model_checkpoint_path is None:
raise Exception("没有持久化好的模型!!!")
saver = tf.train.import_meta_graph(meta_graph_or_file="{}.meta".format(ckpt.model_checkpoint_path))
# a. 恢复模型
saver.restore(sess, ckpt.model_checkpoint_path)
# b. 训练数据的产生/获取(基于numpy随机产生<可以先考虑一个固定的数据集>)
N = 10
dim = 2
x = np.random.uniform(low=-10, high=10, size=(N, dim))
y = np.dot(x, [[5], [-0.5]]) + 12 + np.random.normal(0, 5.0, (N, 1))
x.shape = -1, dim
y.shape = -1, 1
print((np.shape(x), np.shape(y)))
# 测试的误差
loss = tf.get_default_graph().get_tensor_by_name('loss/Mean:0')
y_ = tf.get_default_graph().get_tensor_by_name('network/add:0')
input_x = tf.get_default_graph().get_tensor_by_name('network/x:0')
input_y = tf.get_default_graph().get_tensor_by_name('network/y:0')
loss_, predict = sess.run([loss, y_], feed_dict={
input_x: x,
input_y: y
})
print("模型测试的损失函数为:{}".format(loss_))
print("预测值为:{}".format(np.reshape(predict, -1)))
print("实际值:{}".format(np.reshape(y, -1)))
通常,我们直接在代码中加一段代码就可以了~
ckpt = tf.train.get_checkpoint_state('./models/18/model')
if ckpt and ckpt.model_checkpoint_path:
print("进行模型恢复操作...")
# 恢复模型
saver.restore(sess, ckpt.model_checkpoint_path)
# 恢复checkpoint的管理信息
saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
else:
# 如果文件不存在,进行初始化
print("进行模型参数初始化操作...")
sess.run(tf.global_variables_initializer())
其中一些函数的解释:
- tf.train.get_checkpoint_state函数:通过checkpoint文件找到模型文件名。如果print它,就会出现
本质上其实就是找到保存的最新的5个模型~
2.ckpt.model_checkpoint_path:这个函数的目的是找到最新的模型的路径,print的结果为:
可以参考的博客:https://zhuanlan.zhihu.com/p/45918984 --模型保存
https://zhuanlan.zhihu.com/p/46088787 --模型加载
(我感觉这个博主写的很详细~)