tensorflow2 OD api训练自己的数据集

接上篇文章安装跑通环境,下来就是训练预测了

一、制作数据集

下载labelimg打标签(具体不多介绍,CSD上到处都是)

在object_detection下建立training_demo文件夹,作为根目录。然后在里面建立如下文件夹,多的后面自动会生成

annotations存放生成的数据,record格式的数据集,pbtxt文件标签,img预测使用的图片

pb.txt文件为自己构建类标签

exported-models存放训练好的模型、pre-trained-models存放预训练的模型、models存放预训练模型修改后的模型、images存放image打标签后的数据。image下包含train和test文件,将分类好的图片及其xml文件放入

在根目录下编译generate_tfrecord.py

 python generate_tfrecord.py -x C:\Users\Admin\Downloads\models-master\research\object_detection\training_demo\images\test -l C:\Users\Admin\Downloads\models-master\research\object_detection\training_demo\annotations\label_map.pbtxt -o C:\Users\Admin\Downloads\models-master\research\object_detection\training_demo\annotations\test.record
""" Sample TensorFlow XML-to-TFRecord converter

usage: generate_tfrecord.py [-h] [-x XML_DIR] [-l LABELS_PATH] [-o OUTPUT_PATH] [-i IMAGE_DIR] [-c CSV_PATH]

optional arguments:
  -h, --help            show this help message and exit
  -x XML_DIR, --xml_dir XML_DIR
                        Path to the folder where the input .xml files are stored.
  -l LABELS_PATH, --labels_path LABELS_PATH
                        Path to the labels (.pbtxt) file.
  -o OUTPUT_PATH, --output_path OUTPUT_PATH
                        Path of output TFRecord (.record) file.
  -i IMAGE_DIR, --image_dir IMAGE_DIR
                        Path to the folder where the input image files are stored. Defaults to the same directory as XML_DIR.
  -c CSV_PATH, --csv_path CSV_PATH
                        Path of output .csv file. If none provided, then no file will be written.
"""

import os
import glob
import pandas as pd
import io
import xml.etree.ElementTree as ET
import argparse

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'    # Suppress TensorFlow logging (1)
import tensorflow.compat.v1 as tf
from PIL import Image
from object_detection.utils import dataset_util, label_map_util
from collections import namedtuple

# Initiate argument parser
parser = argparse.ArgumentParser(
    description="Sample TensorFlow XML-to-TFRecord converter")
parser.add_argument("-x",
                    "--xml_dir",
                    help="Path to the folder where the input .xml files are stored.",
                    type=str)
parser.add_argument("-l",
                    "--labels_path",
                    help="Path to the labels (.pbtxt) file.", type=str)
parser.add_argument("-o",
                    "--output_path",
                    help="Path of output TFRecord (.record) file.", type=str)
parser.add_argument("-i",
                    "--image_dir",
                    help="Path to the folder where the input image files are stored. "
                         "Defaults to the same directory as XML_DIR.",
                    type=str, default=None)
parser.add_argument("-c",
                    "--csv_path",
                    help="Path of output .csv file. If none provided, then no file will be "
                         "written.",
                    type=str, default=None)

args = parser.parse_args()

if args.image_dir is None:
    args.image_dir = args.xml_dir

label_map = label_map_util.load_labelmap(args.labels_path)
label_map_dict = label_map_util.get_label_map_dict(label_map)


def xml_to_csv(path):
    """Iterates through all .xml files (generated by labelImg) in a given directory and combines
    them in a single Pandas dataframe.

    Parameters:
    ----------
    
  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
基于TensorFlow Object Detection API搭建自己的物体识别模型的代码如下: 1. 准备工作: - 安装TensorFlow Object Detection API - 准备训练和测试数据集 - 下载预训练的模型权重 2. 导入所需库: ```python import tensorflow as tf from object_detection.utils import dataset_util from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util ``` 3. 加载label map和模型: ```python PATH_TO_LABELS = 'path_to_label_map.pbtxt' PATH_TO_MODEL = 'path_to_pretrained_model' label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=90, use_display_name=True) category_index = label_map_util.create_category_index(categories) detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') ``` 4. 定义函数进行物体识别: ```python def detect_objects(image): with detection_graph.as_default(): with tf.Session(graph=detection_graph) as sess: image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') image_expanded = np.expand_dims(image, axis=0) (boxes, scores, classes, num) = sess.run( [detection_boxes, detection_scores, detection_classes, num_detections], feed_dict={image_tensor: image_expanded}) vis_util.visualize_boxes_and_labels_on_image_array( image, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8) return image ``` 5. 加载测试图像并进行物体识别: ```python image = cv2.imread('test_image.jpg') image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) output_image = detect_objects(image) cv2.imshow('Object Detection', output_image) cv2.waitKey(0) cv2.destroyAllWindows() ``` 通过以上代码,可以使用自己的训练数据集、预训练模型权重和标签映射文件来搭建自己的物体识别模型。设置好路径并加载模型后,将待识别的图像传入`detect_objects`函数即可返回识别结果,并在图像上进行可视化展示。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值