Windows 利用tensorflow2 object-detection api训练自己的模型

预备工作:

windows tensorflow2 object-detection api 下载安装

根据这个文档一步步走

训练自定义对象检测器 — 张量流 2 对象检测 API 教程文档 (tensorflow-object-detection-api-tutorial.readthedocs.io)

1.准备数据集

需将数据集分为训练集和测试集两部分,使用labelImg工具打标签生成.xml文件,然后将.xml文件转成.record文件。

2.训练模型

models/tf2_detection_zoo.md at master · tensorflow/models · GitHub 在这里下载需要的模型,配置config文件时最好将batch_size的值改为1,以防显存不足,当然也可以自己设置一个合适的值。

可能遇到的报错:

self._read_buf = _pywrap_file_io.BufferedInputStream(
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xd5 in position 114: invalid continuation byte

可能的解决方法:

配置模型config文件时,将路径中的‘ \ '改成' /'.如:'.....training_demo/annotations/label_map.pbtxt'

路径中含有中文,需将路径改为全英文

AssertionError: Found 260 Python objects that were not bound to checkpointed values, likely due to changes in the Python program.

解决方法:将 fine_tune_checkpoint_type: "classification" 改为 fine_tune_checkpoint_type: "detection"

3.导出.pb模型

用文档的方法即可

4.测试训练好的模型

测试模型的代码转自:
原文链接:https://blog.csdn.net/weixin_48672949/article/details/118808852 

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Created on 2020.10.6
@auther:Jacklee
"""
import os

import cv2
import numpy as np
from PIL import Image
import tkinter
import matplotlib

matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

import time
import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder
from six import BytesIO


def load_image_into_numpy_array(path):
    """Load an image from file into a numpy array.
      Puts image into numpy array to feed into tensorflow graph.
      Note that by convention we put it into a numpy array with shape
      (height, width, channels), where channels=3 for RGB.
      Args:
        path: the file path to the image
      Returns:
        uint8 numpy array with shape (img_height, img_width, 3)
      """
    img_data = tf.io.gfile.GFile(path, 'rb').read()
    image = Image.open(BytesIO(img_data)).convert('RGB')
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


# build detection model and load trained_model weights
pipeline_config = os.path.join(
    'D:/tensorflow-model/workspace/training_demo/exported-models/my_fast',
    'pipeline.config')
model_dir = 'D:/tensorflow-model/workspace/training_demo/exported-models/my_fast/checkpoint/'
print(pipeline_config)
print(model_dir)
# Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
detection_model = model_builder.build(
    model_config=model_config, is_training=False)

# Restore checkpoint
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()


def get_model_detection_function(model):
    """Get a tf.function for detection."""

    @tf.function
    def detect_fn(image):
        """Detect objects in image."""

        image, shapes = model.preprocess(image)
        prediction_dict = model.predict(image, shapes)
        detections = model.postprocess(prediction_dict, shapes)

        return detections, prediction_dict, tf.reshape(shapes, [-1])

    return detect_fn


detect_fn = get_model_detection_function(detection_model)

# Load label_map data
label_map_path = configs['eval_input_config'].label_map_path
label_map = label_map_util.load_labelmap(label_map_path)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=label_map_util.get_max_label_map_index(label_map),
    use_display_name=True)
category_index = label_map_util.create_category_index(categories)
label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)

image_dir = 'D:/tensorflow-model/workspace/training_demo/images/1/'
PATH_TO_RESULT = 'D:/tensorflow-model/workspace/training_demo/images/fast/'
IMAGE_PATH_CHAR = []

start = time.time()

for image in os.listdir(image_dir):
    if image.endswith(".jpg") or image.endswith(".png"):
        IMAGE_PATH_CHAR.append(os.path.join(image_dir, image))  # 将每张图像的完整路径加入到列表中

for image_path in IMAGE_PATH_CHAR:
    print("Running inference for {}...".format(image_path), end='')
    image_np = load_image_into_numpy_array(image_path)
    input_tensor = tf.convert_to_tensor(
        np.expand_dims(image_np, 0), dtype=tf.float32)
    detections, predictions_dict, shapes = detect_fn(input_tensor)
    label_id_offset = 1
    image_np_with_detections = image_np.copy()

    viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections,
        detections['detection_boxes'][0].numpy(),
        (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
        detections['detection_scores'][0].numpy(),
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=200,
        min_score_thresh=.30,
        agnostic_mode=False)

    plt.figure(figsize=(12, 8))
    print(image_path.split('/')[-1])
    cv2.imwrite(PATH_TO_RESULT + image_path.split('/')[-1], image_np_with_detections)
    plt.subplot(122)
    plt.imshow(image_np_with_detections)
    plt.show()
end = time.time()
print('Execution Time: ', end - start)

上述代码需更改四处文件目录:

pipeline_config = os.path.join('training_mouse\\train_export',  'pipeline.config')
.pb文件位置:model_dir = 'training_mouse\\train_export\\checkpoint\\'                                            测试图片位置:image_dir = 'training_mouse\\test_image\\'
测试结果图:PATH_TO_RESULT = 'training_mouse\\test_image\\'

将该代码保存到training_demo文件夹下,cmd中cd到该目录下运行该.py文件。

5.参考博客

tensorflow2.0训练目标检测模型_weixin_48672949的博客-CSDN博客_tensorflow2.0 目标检测

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值