本项目是在利用models/tensorflow中SSD_MobileNet网络模型进行的实验。
github链接:https://github.com/tensorflow/models/tree/master/research/object_detection
主要分为数据预处理、训练和测试部分(验证部分直接用github上的教程)
1. 数据预处理
在github教程上有对pets数据和pascalVOC数据的转.record文件的相关代码,但是由于一般普通数据通常只能获得.jpg(.jpeg)和.xml文件数据,那么不能直接使用github上现成的create_record代码,这里采用官方代码并进行了修改。
r"""Convert raw PASCAL dataset to TFRecord for object_detection.
Example usage:
./create_pascal_tf_record --data_dir=/home/user/VOCdevkit \
--year=VOC2012 \
--output_path=/home/user/pascal.record
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import io
import logging
import os
from lxml import etree
import PIL.Image
import tensorflow as tf
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
def dict_to_tf_example(data,
dataset_directory,
label_map_dict,
ignore_difficult_instances=False,
image_subdirectory='image'):
"""Convert XML derived dict to tf.Example proto.
Notice that this function normalizes the bounding box coordinates provided
by the raw data.
Args:
data: dict holding PASCAL XML fields for a single image (obtained by
running dataset_util.recursive_parse_xml_to_dict)
dataset_directory: Path to root direct