(更新视频教程)Tensorflow object detection API 搭建属于自己的物体识别模型(2)——训练并使用自己的模型

本文详细介绍了如何使用Tensorflow Object Detection API搭建并训练物体识别模型,包括创建训练/测试数据集,配置文件,训练模型,以及测试模型的步骤。作者分享了从标注数据到使用自定义模型的全过程,并提供了相关资源链接和视频教程。
摘要由CSDN通过智能技术生成

2019.06.19

Tensorflow 学习交流6群

点击链接加入群聊【Tensorflow 学习交流6群】:https://jq.qq.com/?_wv=1027&k=5bBaezq

 

------------------------------------------------------------------------------------------------------------------------

2019.05.16

点击链接加入群聊【Tensorflow 学习交流5群】:https://jq.qq.com/?_wv=1027&k=5ZTCEf0

------------------------------------------------------------------------------------------------------------------------

2019.05.04

1,2,3群都已经满了,

4群号:576965570

由于QQ等级限制,3,4群上限只有200人,如果有读者能够帮忙创建更大的群,请与我联系,谢谢。

---------------------------------------------------------------------------------------------------------------

2019.03.11

1,2群均已满,3群已经创建,大家学习热情太高了,之前有的同学没有进来,感谢支持理解~

3群号:792673238

点击链接加入群聊【Tensorflow学习交流3群】:https://jq.qq.com/?_wv=1027&k=5IFb8he

----------------------------------------------------------------------------------------------------------------------------------------

2018.05.10

本人时差党,有时候回复不及时。创建了一个QQ群,方便大家互相学习交流。

---------------------------------------------------------------------------------------------------------------------------------------

 

2群号: 902067304

---------------------------------------------------------------------------------------------------------------

(1群人已满)点击链接加入群聊【Tensorflow学习交流群】:https://jq.qq.com/?_wv=1027&k=55j9V1r

 

------------------------------------------------------------------------------------------------

 

2018.05.04更新!

如何将训练好的模型移植到Android手机上:

https://blog.csdn.net/dy_guox/article/details/80192343

视频演示:

https://www.bilibili.com/video/av22957279/

-----------------------------------------------------------------------------------------------------

2018.04.02更新!

https://www.bilibili.com/video/av21539370/

系列操作视频已经上传,请有需要的读者自行前往。写博客的时候Tensorflow是1.4版本,视频里更新的是1.7版本,这中间遇到非常多的问题,加上第一次做视频,难免有很多问题,感谢理解!

 

 

https://blog.csdn.net/dy_guox/article/details/80139981

另外一个博客里更新了常见问题汇总,大家可以去看一下,欢迎分享或纠正!

 

------------------------------------------------------------------------------------------------------------------------------

 

在上一篇博客中(http://blog.csdn.net/dy_guox/article/details/79081499),我们成功安装了Tensorflow Object Detection API所需的开发环境,并在官方的Demo上成功进行了测试,接下来尝试运用自己的数据进行训练与测试。

 

项目代码汇总:

https://github.com/XiangGuo1992/Screen-Vehicle-Detection-using-Tensorflow-API

 

 

一、分析代码结构

仍然打开object_detection文件夹中的 object_detection_tutorial.ipynb ,分析代码结构。

第一部分Imports导入需要的包,不需要做更改。

 

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

if tf.__version__ < '1.4.0':
  raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')

 

 

第二部分Env setup 设置系统环境,不必更改。

 

# This is needed to display the images.
%matplotlib inline

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")

 

 

第三部分Object detection imports 导入Object detection 需要的模块,如果报错,说明工作目录设置不对,或者.../research以及.../research/slim 的环境变量没有设置好。

 

from utils import label_map_util

from utils import visualization_utils as vis_util


第四部分为设置模型的对应参数。

 

# 下载模型的名字
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Path to frozen detection graph. This is the actual model that is used for the object detection. 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90

github上有对应官方的各种模型(地址摸我),这些都是基于不用的数据集事先训练好的模型,下载好以后就可以直接调用。下载的文件以 '.tar.gz'结尾。'PATH_TO_CKPT'为‘.pb’文件的目录,'.pb'文件是训练好的模型(frozen detection graph),即用来预测时使用的模型。‘PATH_TO_LABELS’为标签文件,记录了哪些标签需要识别,'NUM_CLASSES'为类别的数目,根据实际需要修改。

 

见上图,第一列是模型名字,第二列是速度,第三列是精度。这里需要注意几点:

1、Model name上的名字与代码中“MODEL_NAME”后面变量的名字不一样,可以发现后者还有日期,在写代码的时候需要像后者那样将名字写完整,想得到完整的名字,可以直接在网站上点击对应的模型,弹出“另存为”对话框时就能够发现完整的MODEL_NAME”,如下图所示。

2、列表中速度快的模型,一般自己训练也会快,但是精度高的不一定使用自己的数据集时精度也高,因为训练的数据集及模型参数可能本身就存在差异,建议先用Demo中的‘ssd_mobilenet_v1_coco_2017_11_17’,速度最快。
 

第五部分Download Model 为下载模型,通过向对应网站发送请求进行下载解压操作。第六部分Load a (frozen) Tensorflow model into memory 将训练完的模型载入内存,第六部分Loading label map将标签map载入,这几个部分都不用修改,直接复制即可。

 

opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
  file_name = os.path.basename(file.name)
  if 'frozen_inference_graph.pb' in file_name:
    tar_file.extract(file, os.getcwd())
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')
基于TensorFlow Object Detection API搭建自己的物体识别模型的代码如下: 1. 准备工作: - 安装TensorFlow Object Detection API - 准备训练和测试数据集 - 下载预训练模型权重 2. 导入所需库: ```python import tensorflow as tf from object_detection.utils import dataset_util from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util ``` 3. 加载label map和模型: ```python PATH_TO_LABELS = 'path_to_label_map.pbtxt' PATH_TO_MODEL = 'path_to_pretrained_model' label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=90, use_display_name=True) category_index = label_map_util.create_category_index(categories) detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') ``` 4. 定义函数进行物体识别: ```python def detect_objects(image): with detection_graph.as_default(): with tf.Session(graph=detection_graph) as sess: image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') image_expanded = np.expand_dims(image, axis=0) (boxes, scores, classes, num) = sess.run( [detection_boxes, detection_scores, detection_classes, num_detections], feed_dict={image_tensor: image_expanded}) vis_util.visualize_boxes_and_labels_on_image_array( image, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8) return image ``` 5. 加载测试图像并进行物体识别: ```python image = cv2.imread('test_image.jpg') image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) output_image = detect_objects(image) cv2.imshow('Object Detection', output_image) cv2.waitKey(0) cv2.destroyAllWindows() ``` 通过以上代码,可以使用自己的训练数据集、预训练模型权重和标签映射文件来搭建自己的物体识别模型。设置好路径并加载模型后,将待识别的图像传入`detect_objects`函数即可返回识别结果,并在图像上进行可视化展示。
评论 658
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值