【深度学习】tflite模型预测

一、在tensorflow官网下载tflite模型和标签。以分类模型为例:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

二、使用下载的模型和标签实现分类预测

  整个项目结构:
                  在这里插入图片描述

import tensorflow as tf
import numpy as np
import cv2 as cv2

#读取标签文件,返回列表
def read_label_list():
    with open('E:\PycharmProject\learning\classification03\labels_mobilenet_quant_v1_224.txt', 'r',encoding="utf8") as f:
        data = f.read().splitlines()
    return data
#图片处理,
def image_process(image_path):
    image=cv2.imread(image_path)
    image=cv2.resize(image,(224,224))
    image=tf.convert_to_tensor(image)
    image=tf.reshape(image,[1,224,224,3])
    image = tf.cast(image, dtype=np.uint8)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    image = image.eval(session=sess)  # 转化为numpy数组
    return image
def main():
    # 加载模型
    interpreter = tf.lite.Interpreter(model_path="E:\PycharmProject\learning\classification03\mobilenet_v1_1.0_224_quant.tflite")
    interpreter.allocate_tensors()
    
    
    # 模型输入和输出细节
    input_details = interpreter.get_input_details()
    #print(input_details)
    #[{'name': 'input', 'index': 88, 'shape': array([  1, 224, 224,   3]), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128)}]
    output_details = interpreter.get_output_details()
    #print(output_details)
    #[{'name': 'MobilenetV1/Predictions/Reshape_1', 'index': 87, 'shape': array([   1, 1001]), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.00390625, 0)}]
    

    #图片传入与处理
    image_path=input('输入图片地址:')
    image=image_process(image_path)

    #模型预测
    interpreter.set_tensor(input_details[0]['index'], image)#传入的数据必须为ndarray类型
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    #标签预测
    w = np.argmax(output_data)#值最大的位置
    lable_list = read_label_list()#读取标签列表
    print(lable_list[w])

if __name__ == '__main__':
    main()
TensorFlow是一个广泛用于机器学习和深度学习的开源框架。它提供了许多功能强大的工具和接口,使得加载和解析tflite模型变得相对简单。 要加载和解析tflite模型,首先需要使用TensorFlow提供的tflite模块。我们可以使用以下代码导入tflite模块: ``` import tensorflow as tf interpreter = tf.lite.Interpreter(model_path="model.tflite") interpreter.allocate_tensors() ``` 在上述代码中,我们首先导入tensorflow模块,并创建了一个tf.lite.Interpreter对象。通过指定模型的路径"model.tflite",我们将tflite模型加载到内存中。然后,我们使用interpreter对象的allocate_tensors方法来为模型分配所需的张量。 加载完成后,可以使用interpreter对象的get_input_details和get_output_details方法获取模型的输入和输出张量的详细信息: ``` input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() ``` 通过打印input_details和output_details,我们可以获得输入和输出张量的名称、形状、数据类型等详细信息。 接下来,我们可以准备要输入模型的数据,并将其设置为输入张量的值: ``` input_data = ... interpreter.set_tensor(input_details[0]['index'], input_data) ``` 在上述代码中,我们将input_data设置为我们要输入模型的数据,并使用interpreter对象的set_tensor方法将其设置为输入张量的值。input_details[0]['index']表示输入张量的索引。 然后,我们可以使用interpreter对象的invoke方法来运行模型: ``` interpreter.invoke() ``` 运行模型后,我们可以通过获取输出张量的值来获取模型预测结果: ``` output_data = interpreter.get_tensor(output_details[0]['index']) ``` 将output_data打印出来,我们可以获得模型预测结果。 总结起来,使用TensorFlow加载和解析tflite模型的步骤包括:导入tflite模块、创建Interpreter对象并加载tflite模型、获取输入和输出张量的详细信息、设置输入张量的值、运行模型并获取输出张量的值。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值