将训练好的模型保存并使用

这篇文章为说明将训练好的模型保存为ckpt后,再生成PB文件,并调用的方法。其中,生成和调用的部分代码是统一的格式,可以收藏。训练图用的是前一篇博文泰坦尼克生存率预测,可以参考:https://blog.csdn.net/aspirinLi/article/details/105350452。由于这个图的输入和输出都较为简单,测试模型部分也简化了,图像的输入还需要找机会尝试。

一、首先将训练好的模型保存为检查点

1、在Session()前写上代码:

saver = tf.train.Saver()

2、在Session()结束时写上:

saver.save(sess,'./model/model.ckpt',global_step=i+1)

以下是预测泰坦尼克代码的部分:

saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
    loss_train = []
    train_acc = []
    test_acc = []
    sess.run(init)
    for i in range(10000):
        for n in range(len(data_target)//100+1):
            batch_xs = data_train[n*100:n*100+100]
            batch_ys = data_target[n * 100:n * 100 + 100]
            sess.run(step,feed_dict={x:batch_xs,y:batch_ys})
        if i%1000 == 0:
            loss_temp = sess.run(loss,feed_dict={x:batch_xs,y:batch_ys})
            loss_train.append(loss_temp)
            train_acc_temp = sess.run(accuracy,feed_dict={x:batch_xs,y:batch_ys})
            train_acc.append(train_acc_temp)
            test_acc_temp = sess.run(accuracy,feed_dict={x:data_test,y:test_label})
            test_acc.append(test_acc_temp)
            print(loss_temp,train_acc_temp,test_acc_temp)
            saver.save(sess,'./model/model.ckpt',global_step=i+1)

泰坦尼克完整代码见:https://blog.csdn.net/aspirinLi/article/details/105350452

于是,我们就可以边训练边保存模型,完成后生成下列文件:

 

二、生成PB文件

首先提示一下,可以将需要固化的值在训练的时候就给出名字,这样取值会方便。比如:

x = tf.placeholder(shape=[None,12],dtype = tf.float32,name='input_x')

output = tf.add(tf.matmul(x,weight),bias,name = 'output')

接下来生成PB文件:

import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np


def freeze_graph(input_checkpoint, output_graph):
    #这里选择需要固化的量
    output_node_names = "input_x,output"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=sess.graph_def,  # 等于:sess.graph_def
            output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点

# 输入ckpt模型路径
input_checkpoint='D:/TensorflowTest/taitanic/model/model.ckpt-9001'
# 输出pb模型的路径
out_pb_path="D:/TensorflowTest/taitanic/model/frozen_model.pb"
# 调用freeze_graph将ckpt转为pb
freeze_graph(input_checkpoint,out_pb_path)

三、调用PB文件

这里选择的测试文件是test.csv。前面的处理参考泰坦尼克预测存活率那篇博文。

import tensorflow as tf
import pandas as pd
import os
import numpy as np
from tensorflow.python.platform import gfile


# 先检测看pb文件是否存在
savePbFile = './model/frozen_model.pb'
if os.path.exists(savePbFile) is False:
    print('Not found pb file!')
    exit()


data_test = pd.read_csv('./data/test.csv')
data_test = data_test[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Cabin', 'Embarked']]
data_test['Age'] = data_test['Age'].fillna(data_test['Age'].mean())
data_test['Cabin'] = pd.factorize(data_test.Cabin)[0]
data_test.fillna(0, inplace=True)
data_test['Sex'] = [1 if x == 'male' else 0 for x in data_test.Sex]
data_test['p1'] = np.array(data_test['Pclass'] == 1).astype(np.int32)
data_test['p2'] = np.array(data_test['Pclass'] == 2).astype(np.int32)
data_test['p3'] = np.array(data_test['Pclass'] == 3).astype(np.int32)
del data_test['Pclass']
data_test['e1'] = np.array(data_test['Embarked'] == 'S').astype(np.int32)
data_test['e2'] = np.array(data_test['Embarked'] == 'C').astype(np.int32)
data_test['e3'] = np.array(data_test['Embarked'] == 'Q').astype(np.int32)
del data_test['Embarked']

#print(data_test)


#data_test = data_test.astype()
gender = pd.read_csv('./data/gender.csv')
gender = np.reshape(gender.Survived.values.astype(np.float32),(418,1))


with tf.Session() as sess:
    # 打开pb模型文件
    with gfile.FastGFile(savePbFile, 'rb') as fd:
        # 导入图
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fd.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        # 根据名字获取对应的tensorflow
        input = sess.graph.get_tensor_by_name('input_x:0')
        output = sess.graph.get_tensor_by_name('output:0')
        output_e = tf.cast(tf.sigmoid(output)>0.5,tf.float32)
        acc = tf.reduce_mean(tf.cast(tf.equal(gender, output_e),dtype = tf.float32))
        result = sess.run(output_e, feed_dict={input:data_test})
        accuracy = sess.run(acc,feed_dict={input:data_test})
        print(result)
        print(accuracy)

即可。

要调用保存好的 TensorFlow 模型,可以使用以下步骤: 1. 定义模型结构和训练过程。 2. 创建一个 `tf.train.Saver` 对象,用于保存和恢复模型。 3. 在训练结束后,调用 `saver.save()` 方法保存模型。 4. 在测试或预测过程中,使用 `tf.train.import_meta_graph()` 方法加载模型的图结构。 5. 创建一个 `tf.Session` 对象,并使用 `saver.restore()` 方法恢复模型的参数。 6. 在 `Session` 中执行模型的前向传播操作,获取预测结果。 以下是一个简单的示例代码,展示如何加载保存好的 TensorFlow 模型: ```python import tensorflow as tf # 定义模型结构和训练过程 x = tf.placeholder(tf.float32, [None, 784], name='x') y = tf.placeholder(tf.float32, [None, 10], name='y') w = tf.Variable(tf.zeros([784, 10]), name='w') b = tf.Variable(tf.zeros([10]), name='b') logits = tf.matmul(x, w) + b loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)) train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss) # 创建 Saver 对象 saver = tf.train.Saver() with tf.Session() as sess: # 恢复模型的图结构 saver = tf.train.import_meta_graph('model.ckpt.meta') # 加载模型的参数 saver.restore(sess, 'model.ckpt') # 获取模型的输入和输出张量 graph = tf.get_default_graph() x = graph.get_tensor_by_name('x:0') y = graph.get_tensor_by_name('y:0') logits = graph.get_tensor_by_name('add:0') # 执行模型的前向传播操作 predictions = tf.argmax(logits, axis=1) test_data = ... test_labels = ... feed_dict = {x: test_data, y: test_labels} results = sess.run(predictions, feed_dict=feed_dict) ``` 在上述代码中,我们首先定义了一个简单的模型结构和训练过程,并使用 `tf.train.Saver` 对象保存模型。在测试或预测过程中,我们使用 `tf.train.import_meta_graph()` 方法加载了模型的图结构,并使用 `saver.restore()` 方法恢复了模型的参数。然后,我们通过 `graph.get_tensor_by_name()` 方法获取了模型的输入和输出张量,并执行了模型的前向传播操作,获取了预测结果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值