基于Tensorflow2.2 object detection API使用自训练模型对检测目标做实际检测

基于Tensorflow2.2 object detection API使用自训练模型对检测目标做实际检测

1 前言

使用Tensorflow 2.2 object detection api训练自己的数据集,并对标定的目标物进行实际检测的完整过程终于到了最后一个步骤,即使用训练好的模型对测试集里的包含目标物的图像做实际的检测。本篇博客的主要内容包括使用训练好的CenterNet_ResNet50_v2导出模型对数据集里标定的目标做实际的检测,测试集里包括73张图片,这些图片未包含在训练集和评估集里,最终得到实际的检测效果。涉及的工作内容是在前面的成果基础上进行的,如果读者在看到这篇博客的内容并有一定的兴趣,请先查看前面博客的内容。另外单独将这部分实际检测的内容开一篇博客的原因是在写上一篇内容的时候还不能成功地使用训练好的模型做实际的检测,所以笔者也是在边学习摸索边进行自己的项目。在此过程中,发现网上的相关资料很少,搜索到的大部分是使用tensorflow1.x进行目标检测的内容,正因如此更加觉得有必要将一点点经验分享给大家,文章里的内容仅供参考,希望对正在使用Tensorfow 2.2 object detection api 的人能够有一点帮助,也希望大家共同进步!

2 下载训练模型

如果模型训练是在云端进行的,那么开始检测前就需要将训练好的模型下载下来,如果训练过程是基于本地则可以直接调用。训练好的模型导出格式和预训练模型是一样的,包含三个文件,如图所示。
在这里插入图片描述

3 测试集检测

模型训练的最终目的就是就是对你想检测的目标做实际的检测,我的测试集包含如下图所示的图片
在这里插入图片描述

4 测试效果

对以上的六张图像测试效果如下
在这里插入图片描述
图像中的目标物均被有效识别出来。

5 检测代码

检测代码并不唯一,但网上能跑通却很少,由于tensoeflow2.2 环境下检测代码和tensorflow 1.x有一些不同的地方,而网上大都数的检测代码是适用于tensorflow1.x的。tensorflow object detection的git hub账号有一个相关的检测推理部署教程,在这里一并贴出网址,可以根据自己的实际检测项目做简单的调整。我的代码也会在下面贴出,仅供参考!
在对数据集检测过程中遇到了一个错误,如下图所示,CenterNet_predict.py是我的检测程序,我的测试集是尺寸和通道就是(512, 512,3),一开始搞不懂到底哪里错误,反反复复修改代码,依旧报错!

File “CenterNet_predict.py”, line 37, in load_image_into_numpy_array
return np.array(image.getdata()).reshape(
ValueError: cannot reshape array of size 262144 into shape (512,512,3)

在这里插入图片描述
报错代码如下
在这里插入图片描述

查看官方的测试图像信息发现测试图像的位深是24,而我的图像位深是8,所以这段代码总是无法正常执行。
在这里插入图片描述
想要顺利执行,只好把图像位深转化为24位,而’RGB’格式的图像位深就是24位,于是稍作改动,代码如下

img_data = tf.io.gfile.GFile(path, 'rb').read()
    image = Image.open(BytesIO(img_data)).convert('RGB')
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)

完整转化代码贴在这里:

"""
Created on 2020.10.6
@auther:Jacklee
"""
import os

import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import time
import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder
from six import BytesIO


def load_image_into_numpy_array(path):
    """Load an image from file into a numpy array.

      Puts image into numpy array to feed into tensorflow graph.
      Note that by convention we put it into a numpy array with shape
      (height, width, channels), where channels=3 for RGB.

      Args:
        path: the file path to the image

      Returns:
        uint8 numpy array with shape (img_height, img_width, 3)
      """
    img_data = tf.io.gfile.GFile(path, 'rb').read()
    image = Image.open(BytesIO(img_data)).convert('RGB')
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


#build detection model and load trained_model weights
pipeline_config = os.path.join('/home/lzy/models-master/research/object_detection/save_result_frozenlasercenternet',
                               'pipeline.config')
model_dir = '/home/lzy/models-master/research/object_detection/save_result_frozenlasercenternet/checkpoint/'

# Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
detection_model = model_builder.build(
      model_config=model_config, is_training=False)

# Restore checkpoint
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()

def get_model_detection_function(model):
  """Get a tf.function for detection."""

  @tf.function
  def detect_fn(image):
    """Detect objects in image."""

    image, shapes = model.preprocess(image)
    prediction_dict = model.predict(image, shapes)
    detections = model.postprocess(prediction_dict, shapes)

    return detections, prediction_dict, tf.reshape(shapes, [-1])

  return detect_fn
detect_fn = get_model_detection_function(detection_model)

#Load label_map data
label_map_path = configs['eval_input_config'].label_map_path
label_map = label_map_util.load_labelmap(label_map_path)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=label_map_util.get_max_label_map_index(label_map),
    use_display_name=True)
category_index = label_map_util.create_category_index(categories)
label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)



image_dir = '/home/lzy/models-master/research/object_detection/data_laser512+512/laser-pic/real_test/'
PATH_TO_RESULT = '/home/lzy/models-master/research/object_detection/data_laser512+512/laser-pic/real_test_result/'
IMAGE_PATH_CHAR = []

start = time.time()

for image in os.listdir(image_dir):
    if image.endswith(".jpg") or image.endswith(".png"):
        IMAGE_PATH_CHAR.append(os.path.join(image_dir, image))#将每张图像的完整路径加入到列表中

for image_path in IMAGE_PATH_CHAR:
    print("Running inference for {}...".format(image_path), end='')
    image_np = load_image_into_numpy_array(image_path)
    input_tensor = tf.convert_to_tensor(
         np.expand_dims(image_np, 0), dtype=tf.float32)
    detections, predictions_dict, shapes = detect_fn(input_tensor)
    label_id_offset = 1
    image_np_with_detections = image_np.copy()

    viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections,
        detections['detection_boxes'][0].numpy(),
        (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
        detections['detection_scores'][0].numpy(),
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=200,
        min_score_thresh=.30,
        agnostic_mode=False)
    plt.figure(figsize=(12, 8))
    print(image_path.split('/')[-1])
    cv2.imwrite(PATH_TO_RESULT + image_path.split('/')[-1], image_np_with_detections)
    #plt.imshow(image_np_with_detections)
    #plt.show()
end =time.time()
print('Execution Time: ', end-start)

Trained model for detection inferience

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值