“ 一个人如果不能学会遗忘,那将是很痛苦的事,别再自寻烦恼,快把痛苦的事给忘了吧!”
为了能够使用Object Detection API~
需要将数据集格式转化为.TFRecord再进行训练~
至于,
如何使用Tensorflow官方的Object Detection API
包括下载、依赖(protobuf等)安装、跑demo、训练自己的数据过程~
推荐一篇博文: 1.https://blog.csdn.net/rookie_wei/article/details/81143814
2.https://blog.csdn.net/rookie_wei/article/details/81210499
3.https://blog.csdn.net/rookie_wei/article/details/81275663
整个过程比较详细,可以参考~
本篇主要介绍如何将已标注好的数据集转化成Tensorflow通用的.TFRecord格式~
注意:本程序是我自己检测的6类object,根据情况修改!
#-*- coding=utf-8 -*-
# File Name: Create_TFRecord.py
# Author: HZ
# Created Time: 2018-06-06
import os
import sys
import random
import numpy as np
import tensorflow as tf
import xml.etree.ElementTree as ET #操作xml文件
#我的标签定义有6类,根据自己的图片而定
VOC_LABELS = {
'none': (0, 'Background'),
'person': (1, 'Person'),
'car': (2, 'Car'),
'bus': (3, 'Bus'),
'truck': (4, 'Truck'),
'cyclist': (5, 'cyclist')
}
# 图片和标签存放的文件夹.
DIRECTORY_ANNOTATIONS = 'Annotations/'
DIRECTORY_IMAGES = 'JPEGImages/'
# 随机种子.
RANDOM_SEED = 4242
#生成整数型,浮点型和字符串型的属性
def int64_feature(value):
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def float_feature(value):
if n