TensorFlow 调用预训练好的模型—— Python 实现

1. 准备预训练好的模型

  • TensorFlow 预训练好的模型被保存为以下四个文件

模型文件

  • data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如下所示,为简单起见只保留了一个模型。
model_checkpoint_path: "/home/senius/python/c_python/test/model-40"
all_model_checkpoint_paths: "/home/senius/python/c_python/test/model-40"

2. 导入模型图、参数值和相关变量

import tensorflow as tf
import numpy as np

sess = tf.Session()
X = None # input
yhat = None # output

def load_model():
    """
        Loading the pre-trained model and parameters.
    """
    global X, yhat
    modelpath = r'/home/senius/python/c_python/test/'
    saver = tf.train.import_meta_graph(modelpath + 'model-40.meta')
    saver.restore(sess, tf.train.latest_checkpoint(modelpath))
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name("X:0")
    yhat = graph.get_tensor_by_name("tanh:0")
    print('Successfully load the pre-trained model!')

  • 通过 saver.restore 我们可以得到预训练的所有参数值,然后再通过 graph.get_tensor_by_name 得到模型的输入张量和我们想要的输出张量。

3. 运行前向传播过程得到预测值

def predict(txtdata):
    """
        Convert data to Numpy array which has a shape of (-1, 41, 41, 41 3).
        Test a single example.
        Arg:
                txtdata: Array in C.
        Returns:
            Three coordinates of a face normal.
    """
    global X, yhat

    data = np.array(txtdata)
    data = data.reshape(-1, 41, 41, 41, 3)
    output = sess.run(yhat, feed_dict={X: data})  # (-1, 3)
    output = output.reshape(-1, 1)
    ret = output.tolist()
    return ret

  • 通过 feed_dict 喂入测试数据,然后 run 输出的张量我们就可以得到预测值。

4. 测试

load_model()
testdata = np.fromfile('/home/senius/python/c_python/test/04t30t00.npy', dtype=np.float32)
testdata = testdata.reshape(-1, 41, 41, 41, 3) # (150, 41, 41, 41, 3)
testdata = testdata[0:2, ...] # the first two examples
txtdata = testdata.tolist()
output = predict(txtdata)
print(output)
#  [[-0.13345889747142792], [0.5858198404312134], [-0.7211828231811523], 
# [-0.03778800368309021], [0.9978875517845154], [0.06522832065820694]]
  • 本例输入是一个三维网格模型处理后的 [41, 41, 41, 3] 的数据,输出一个表面法向量坐标 (x, y, z)。

获取更多精彩,请关注「seniusen」!
seniusen

### 回答1: 要调用训练好的TensorFlow模型,需要使用TensorFlow的API来加载模型并进行推理。具体步骤如下: 1. 导入TensorFlow库 ```python import tensorflow as tf ``` 2. 加载模型 ```python model = tf.keras.models.load_model('path/to/model') ``` 其中,`path/to/model`是训练好的模型文件的路径。 3. 进行推理 ```python result = model.predict(input_data) ``` 其中,`input_data`是输入模型的数据,`result`是模型的输出结果。 需要注意的是,加载模型时需要保证模型的结构和训练时一致,否则会出现错误。另外,推理时需要根据模型的输入和输出格式进行相应的数据处理。 ### 回答2: TensorFlow是目前最流行的深度学习框架之一,具有优秀的计算性能和灵活的开发能力。我们经常需要使用TensorFlow调用已经训练好的深度学习模型进行预测或分类任务。下面我将详细介绍如何使用TensorFlow调用训练好的模型。 特别需要注意的是,调用已训练好的模型需要依次完成以下三个步骤: 1. 加载模型 使用TensorFlow加载模型的方式有多种,本文将介绍其中常见的两种方式。 - 从文件中读取模型 使用TensorFlow训练模型时,会生成多个文件,包括模型的结构(.pb),变量的值(.ckpt),以及其他相关文件。我们可以通过tf.train.import_meta_graph()函数来将模型结构从.meta文件中读取出来,然后通过Saver.restore()函数来读取变量的值。 ``` python import tensorflow as tf # 模型路径 model_path = "model/" # 加载模型结构 graph = tf.Graph() with graph.as_default(): saver = tf.train.import_meta_graph(model_path + "model.ckpt.meta") # 加载模型参数 sess = tf.Session(graph=graph) saver.restore(sess, model_path + "model.ckpt") ``` - 直接从.pb文件中读取模型 如果我们直接使用freeze_graph.py将训练好的模型输出为.pb文件,则可直接通过tf.train.import_meta_graph()函数来加载模型。 ``` python import tensorflow as tf # 模型路径 model_path = "model.pb" # 读取模型 with tf.gfile.FastGFile(model_path, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') ``` 2. 获取输入与输出节点 获取模型的输入和输出节点是调用已训练好的模型实现预测或分类的关键步骤。我们需要知道输入和输出节点的名称才能在代码中调用它们。一般可以通过如下位于.pb文件中的代码来查看模型的输入输出节点名称。 ``` python import tensorflow as tf from tensorflow.python.platform import gfile # 模型路径 model_path = "model.pb" # 加载模型 with tf.Session() as sess: #读取保存的模型文件 with gfile.FastGFile(model_path,'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') # 遍历tensor,找到所有的op与tensor for index, t in enumerate(graph_def.node): print("tensor_name:", index, t.name) ``` 其中,模型的输入端一般为数据的placeholder节点,而输出节点则是输出的结果值。 ``` python import tensorflow as tf # 输入节点名称 input_tensor_name = "input:0" # 输出节点名称 output_tensor_name = "output:0" # 获取输入节点 input_tensor = tf.get_default_graph().get_tensor_by_name(input_tensor_name) # 获取输出节点 output_tensor = tf.get_default_graph().get_tensor_by_name(output_tensor_name) ``` 3. 运行预测 获取模型的输入输出节点后,我们可以使用Python代码来调用模型进行预测或分类任务了。这里的关键是要明确输入和输出张量的格式及数据类型。 ``` python import tensorflow as tf import numpy as np import cv2 # 输入数据 input_data = cv2.imread('1.jpg') input_data = cv2.resize(input_data, (224, 224), interpolation=cv2.INTER_CUBIC) input_data = np.expand_dims(input_data, axis=0) # 输入节点名称 input_tensor_name = "input:0" # 输出节点名称 output_tensor_name = "output:0" # 获取输入节点 input_tensor = tf.get_default_graph().get_tensor_by_name(input_tensor_name) # 获取输出节点 output_tensor = tf.get_default_graph().get_tensor_by_name(output_tensor_name) # 运行预测 with tf.Session() as sess: # 输出模型结果 result = sess.run(output_tensor, feed_dict={input_tensor: input_data}) print(result) ``` 需要注意的是,调用已训练好的模型进行预测时,需要提供与训练数据集相同的输入数据格式、数据类型。否则将可能得到不可预测的结果。在调试过程中,可以使用tf.print()函数输出中间过程的值,帮助定位问题。 总之,以上就是关于使用TensorFlow调用训练好的模型的具体步骤和方法。如果您还有任何疑问或需要帮助,请随时联系我们。 ### 回答3: TensorFlow是一个深度学习库,可以用于构建、训练和部署机器学习模型。在TensorFlow中,我们可以使用已经训练好的模型来进行预测任务。在下面的文章中,将介绍如何在TensorFlow调用训练好的模型。 1. 准备数据 在使用训练好的模型之前,我们需要准备输入数据。该数据应该与训练数据一样,包括特征和标签。特征应该是一系列数字或浮点数,而标签是一系列类别或数字。 2. 加载已训练的模型 在加载模型之前,我们需要知道模型保存在哪个路径中。如果您在训练模型时使用了TensorFlow保存模型的方法,那么模型应该保存在一个文件夹中,我们可以通过路径加载模型。 ``` import tensorflow as tf # 指定模型路径 model_path = './model' # 加载模型 model = tf.keras.models.load_model(model_path) ``` 3. 预测数据 已经加载了训练好的模型,可以使用模型对新的数据进行预测。我们可以将待预测的特征传递给模型来进行预测。 ``` import numpy as np # 加载数据 features = np.array([[1.2, 2.5, 3.7], [0.5, 0.9, 1.3]]) # 预测数据 predictions = model.predict(features) ``` 4. 输出结果 预测完成后,我们可以将结果打印出来。如果模型是用来分类的,那么输出值将是每个类别的概率值。如果模型是用来做回归的,那么输出将是预测值。 ``` print(predictions) ``` 以上是在TensorFlow调用训练好的模型的简单步骤。但实际应用中可能会因模型种类等不同因素而有所不同。需要依据具体情况进行调整。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值