Tensorflow---模型持久化的相关问题

Tensorflow—模型持久化的相关问题

TensorFlow使用tf.train.Saver类实现模型的保存和提取。

– 通过Saver对象的restore方法可以加载模型,并通过保存好的模型变量相关值重新加载完全加载进来。
– 如果不希望重复定义计算图上的运算,可以直接加载已经持久化的图,通过tf.train.import_meta_graph方法直接加载

保存模型

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# 加几个随机数种子,让多次运行的时候,随机数列一致
np.random.seed(28)
tf.set_random_seed(28)

if __name__ == '__main__':
    with tf.Graph().as_default():
        # 一、执行图的构建
        with tf.variable_scope('network'):
            # a. 定义占位符
            input_x = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='x')
            input_y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='y')

            # b. 定义模型参数
            w = tf.get_variable(name='w2', shape=[2, 1], dtype=tf.float32,
                                initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0))
            b = tf.get_variable(name='b2', shape=[1], dtype=tf.float32,
                                initializer=tf.zeros_initializer())

            # c. 模型预测的构建(获取预测值)
            y_ = tf.matmul(input_x, w) + b

        with tf.name_scope('loss'):
            # d. 损失函数构建(平方和损失函数)
            loss = tf.reduce_mean(tf.square(input_y - y_))
            tf.summary.scalar('loss', loss)
            print(loss)

        with tf.name_scope('train'):
            # e. 定义优化器(优化器的意思:求解让损失函数最小的模型参数<变量>的方式)
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
            # f. 定义一个训练操作对象
            train_op = optimizer.minimize(loss=loss)

        # 二、执行图的训练运行
        with tf.Session() as sess:
            # a. 创建一个持久化对象
            saver = tf.train.Saver()

            # a. 变量的初始化操作
            sess.run(tf.global_variables_initializer())

            # 获取一个日志输出对象
            writer = tf.summary.FileWriter(logdir='./models/14/graph', graph=sess.graph)
            # 获取所有的summary输出操作
            summary = tf.summary.merge_all()

            # b. 训练数据的产生/获取(基于numpy随机产生<可以先考虑一个固定的数据集>)
            N = 100
            dim = 2
            x = np.random.uniform(low=-10, high=10, size=(N, dim))
            y = np.dot(x, [[5], [-0.5]]) + 12 + np.random.normal(0, 5.0, (N, 1))
            x.shape = -1, dim
            y.shape = -1, 1
            print((np.shape(x), np.shape(y)))

            # c. 模型训练
            for step in range(100):
                # 1. 触发模型训练操作
                _, loss_, summary_ = sess.run([train_op, loss, summary], feed_dict={
                    input_x: x,
                    input_y: y
                })
                print("第{}次训练后模型的损失函数为:{}".format(step, loss_))
                writer.add_summary(summary_, global_step=step)
                # 触发模型持久化
                save_path = './models/14/model/model.ckpt'
                dirpath = os.path.dirname(save_path)
                if not os.path.exists(dirpath):
                    os.makedirs(dirpath)
                saver.save(sess, save_path=save_path) #

            # 关闭输出流
            writer.close()

上面的代码是用tf实现线性回归的简单代码,我们直接在上面进行加模型的持久化操作~
模型持久化的步骤为:

#第一步:创建一个持久化对象
saver = tf.train.Saver()
#第二步:触发模型持久化
save_path = './models/14/model/model.ckpt'
dirpath = os.path.dirname(save_path)
if not os.path.exists(dirpath):
    os.makedirs(dirpath)
saver.save(sess, save_path=save_path)

其中,tf.train.Saver()参数为:

def __init__(self,
               var_list=None, 给定具体持久化那些模型参数,默认是持久化所有的变量<参与模型训练的>
               reshape=False,
               sharded=False,
               max_to_keep=5, 指定最多同时保留最近多少份模型
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None):

比如只想保存其中部分的参数,例如只保存参数w,那么就是:

saver = tf.train.Saver([w])

在tensorflow中,from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file这个api可以查看ckpt文件中保存的参数到底是个啥~

#还是用上述的线性回归代码:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file("D:/chrome_down/05_tf/tensorflow14/tf/models/14/model/model.ckpt",None,True)
'''
运行结果为:
tensor_name:  network/b2
[10.069975]
tensor_name:  network/w2
[[ 5.0179863 ]
 [-0.61515224]]
 
~~~~~通过上面的结果,得知ckpt模型保存了w和b两个参数~~~~~

如果saver = tf.train.Saver([w])的话,运行结果为:
tensor_name:  network/w2
[[ 5.0179863]
 [-0.6151727]]
 ~~~~~通过上面的结果,得知ckpt模型只保存了w两个参数~~~~~
'''

在运行代码结束之后,会在对应的路径生成以下文件:
在这里插入图片描述
下面,讲解以下这四个文件的作用:
checkpoint:checkpoint文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。(如果不指定保存哪些参数,会默认保存所有参数)
model.ckpt.data:保存模型中的参数值。
model.ckpt.index:保存模型的参数名。
model.ckpt.meta:保存图结构。

值得注意的是:当saver.save参数global_step设置为每一步的时候,最后会生成五个最新的步骤生成的ckpt文件~

加载模型

当有保存好的模型,我们可以选择加载它们~

save_path = './models/14/model/model.ckpt'
saver.restore(sess, save_path)
'''如果在模型回复的过程中,参数名字发生改变,加下面一句代码:
saver = tf.train.Saver({"network/w2": w, "network/b2": b})
其中,w和b是名称修改后的tensor对象,w2和b2是模型保存时候的名称
'''

或者:

# 获取持久化的信息对象
            ckpt = tf.train.get_checkpoint_state('./models/18/model')
            print(ckpt.model_checkpoint_path)
            if ckpt and ckpt.model_checkpoint_path:
                print("进行模型恢复操作...")
                # 恢复模型
                saver.restore(sess, ckpt.model_checkpoint_path)
                # 恢复checkpoint的管理信息
                saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
            else:
                # 如果文件不存在,进行初始化
                print("进行模型参数初始化操作...")
                sess.run(tf.global_variables_initializer())

其中,tf.train.get_checkpoint_state的作用是:通过checkpoint文件找到模型文件名~~

如果,代码中没有图的构建,直接通过图的恢复:

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# 加几个随机数种子,让多次运行的时候,随机数列一致
np.random.seed(28)
tf.set_random_seed(28)

if __name__ == '__main__':
    with tf.Graph().as_default():
        with tf.Session() as sess:
            # 恢复图中的执行信息
            ckpt = tf.train.get_checkpoint_state('./models/14/model')
            if ckpt is None or ckpt.model_checkpoint_path is None:
                raise Exception("没有持久化好的模型!!!")
            saver = tf.train.import_meta_graph(meta_graph_or_file="{}.meta".format(ckpt.model_checkpoint_path))

            # a. 恢复模型
            saver.restore(sess, ckpt.model_checkpoint_path)

            # b. 训练数据的产生/获取(基于numpy随机产生<可以先考虑一个固定的数据集>)
            N = 10
            dim = 2
            x = np.random.uniform(low=-10, high=10, size=(N, dim))
            y = np.dot(x, [[5], [-0.5]]) + 12 + np.random.normal(0, 5.0, (N, 1))
            x.shape = -1, dim
            y.shape = -1, 1
            print((np.shape(x), np.shape(y)))

            # 测试的误差
            loss = tf.get_default_graph().get_tensor_by_name('loss/Mean:0')
            y_ = tf.get_default_graph().get_tensor_by_name('network/add:0')
            input_x = tf.get_default_graph().get_tensor_by_name('network/x:0')
            input_y = tf.get_default_graph().get_tensor_by_name('network/y:0')
            loss_, predict = sess.run([loss, y_], feed_dict={
                input_x: x,
                input_y: y
            })
            print("模型测试的损失函数为:{}".format(loss_))
            print("预测值为:{}".format(np.reshape(predict, -1)))
            print("实际值:{}".format(np.reshape(y, -1)))

通常,我们直接在代码中加一段代码就可以了~

            ckpt = tf.train.get_checkpoint_state('./models/18/model')
            if ckpt and ckpt.model_checkpoint_path:
                print("进行模型恢复操作...")
                # 恢复模型
                saver.restore(sess, ckpt.model_checkpoint_path)
                # 恢复checkpoint的管理信息
                saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
            else:
                # 如果文件不存在,进行初始化
                print("进行模型参数初始化操作...")
                sess.run(tf.global_variables_initializer())

其中一些函数的解释:

  1. tf.train.get_checkpoint_state函数:通过checkpoint文件找到模型文件名。如果print它,就会出现
    在这里插入图片描述
    本质上其实就是找到保存的最新的5个模型~
    2.ckpt.model_checkpoint_path:这个函数的目的是找到最新的模型的路径,print的结果为:
    在这里插入图片描述

可以参考的博客:https://zhuanlan.zhihu.com/p/45918984 --模型保存
https://zhuanlan.zhihu.com/p/46088787 --模型加载
(我感觉这个博主写的很详细~)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

进我的收藏吃灰吧~~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值