深度学习:pc端测试tflite模型

  • 本文主要为大家提供一种将训练好的.h5模型转换为 .tflite模型后在PC端测试该模型效果的方法。

下面直接上代码:

import os
import cv2
import numpy as np
import tensorflow as tf
from config import *
import glob
from PIL import Image

# *****************************************测试图片、模型路径******************************************
# 图片集文件夹名称
testFileName = '1'
# 图片集路径
image_dir = os.path.join(os.getcwd(), '../needTestImages', testFileName)
# 得到图片集列表
img_name_list = glob.glob(image_dir + os.sep + '*')

# tflite模型路径
model_path = "../tools/model-simple_0423_float32.tflite"
# 加载并解析模型
interpreter = tf.lite.Interpreter(model_path=model_path)
# 分配空间
interpreter.allocate_tensors()

# 获取输入与输出的详细信息
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("input_details =", input_details)
print("output_details =", output_details)


# ***********************************************归一化化输入*****************************************
def format_input(input_image):
    assert (input_image.size == DEFAULT_IN_SHAPE[:2] or input_image.shape == DEFAULT_IN_SHAPE)
    inputs = np.array(input_image)
    if inputs.shape[-1] == 4:
        input_image = input_image.convert('RGB')
    return np.expand_dims(np.array(input_image) / 255., 0)


# ***********************************************主函数**********************************************
def main():
    # 循环加载待检测图片
    for img in img_name_list:
        print("img =", img)
        start_time = time.time()
        # 将图片转化为RGB格式
        image = Image.open(img).convert('RGB')
        input_image = image

        # 将输入图片缩放到模型的输入尺寸大小
        if image.size != DEFAULT_IN_SHAPE:
            input_image = image.resize(DEFAULT_IN_SHAPE[:2], Image.BICUBIC)

        # 输入图片归一化处理并转化为需要的格式
        input_tensor = format_input(input_image)
        input_tensor = input_tensor.astype(np.float32)

        # 为分配的张量赋值
        index = input_details[0]['index']
        interpreter.set_tensor(index, input_tensor)

        # 调用解释器
        interpreter.invoke()
        # 获得输出
        print("output_details[0]['index'] =", interpreter.get_tensor(output_details[0]['index']))
        output_data = interpreter.get_tensor(output_details[0]['index'])[0][0]
        # 去除不需要的维度,并将数据转化为数组形式
        result = np.squeeze(output_data)
        output_mask = np.asarray(result)

        # 缩放到原图大小
        if image.size != DEFAULT_IN_SHAPE:
            output_mask = cv2.resize(output_mask, dsize=image.size)

        # 转化为3通道灰度图
        output_image = cv2.cvtColor(output_mask.astype('float32'), cv2.COLOR_BGR2RGB) * 255.
        print("host_time =", time.time() - start_time)
        # 转化为单通道灰度图
        output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2GRAY)
        output_location = output_dir.joinpath(pathlib.Path(img).name)
        cv2.imwrite(str(output_location), output_image)


if __name__ == '__main__':
    main()

结果:
原图
在这里插入图片描述
效果图:
在这里插入图片描述
下面是手机端的实测图片:
原图:
在这里插入图片描述

效果图:

在这里插入图片描述

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值