一、在tensorflow官网下载tflite模型和标签。以分类模型为例:
二、使用下载的模型和标签实现分类预测
整个项目结构:
import tensorflow as tf
import numpy as np
import cv2 as cv2
#读取标签文件,返回列表
def read_label_list():
with open('E:\PycharmProject\learning\classification03\labels_mobilenet_quant_v1_224.txt', 'r',encoding="utf8") as f:
data = f.read().splitlines()
return data
#图片处理,
def image_process(image_path):
image=cv2.imread(image_path)
image=cv2.resize(image,(224,224))
image=tf.convert_to_tensor(image)
image=tf.reshape(image,[1,224,224,3])
image = tf.cast(image, dtype=np.uint8)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
image = image.eval(session=sess) # 转化为numpy数组
return image
def main():
# 加载模型
interpreter = tf.lite.Interpreter(model_path="E:\PycharmProject\learning\classification03\mobilenet_v1_1.0_224_quant.tflite")
interpreter.allocate_tensors()
# 模型输入和输出细节
input_details = interpreter.get_input_details()
#print(input_details)
#[{'name': 'input', 'index': 88, 'shape': array([ 1, 224, 224, 3]), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128)}]
output_details = interpreter.get_output_details()
#print(output_details)
#[{'name': 'MobilenetV1/Predictions/Reshape_1', 'index': 87, 'shape': array([ 1, 1001]), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.00390625, 0)}]
#图片传入与处理
image_path=input('输入图片地址:')
image=image_process(image_path)
#模型预测
interpreter.set_tensor(input_details[0]['index'], image)#传入的数据必须为ndarray类型
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
#标签预测
w = np.argmax(output_data)#值最大的位置
lable_list = read_label_list()#读取标签列表
print(lable_list[w])
if __name__ == '__main__':
main()