1、概述
在上一篇我们已经成功的安装了tensorflow与tensorflow 物体检测的API。下面我们将实现使用自定义数据集训练自己的对象检测模型。完成此类工作大概需要以下6个步骤
1. 组织工程文件
2. 组织训练数据集与标注文件
3. 转化训练集为tf_record格式
4. 配置训练流程 pipeline
5. 监控模型训练过程
6. 保存模型参数
2、组织工程文件
2.1 新建workspace文件夹用于存储所有的工程文件
workspace 文件夹在object-detect文件夹下与存储tensorflow object detect api 的目录同级,目录结构如下所示
object-detect
├── models
│ ├── AUTHORS
│ ├── CODEOWNERS
│ ├── community
│ ├── CONTRIBUTING.md
│ ├── ISSUES.md
│ ├── LICENSE
│ ├── official
│ ├── orbit
│ ├── README.md
│ └── research
└── workspace
└── training_demo
2.2 在workspace文件夹下新建train_demo文件
该training_demo文件夹将是我们的训练文件夹,其中将包含与我们的模型训练有关的所有文件。每次我们希望在不同的数据集上进行训练时,建议创建一个单独的训练文件夹。培训文件夹的典型结构如下所示。
training_demo/
├─ addon
├─ annotations/
├─ exported-models/
├─ images/
│ ├─ test/
│ └─ train/
├─ models/
└─ pre-trained-models/
对于每个文件夹的作用说明如下
1. annotations:此文件夹将用于存储所有*.csv文件和各自的TensorFlow*.record文件,其中包含我们的数据集图像的标注列表。
2. exported-models:此文件夹将用于存储我最终模型
3. images:此文件夹包含我们数据集中所有图像的副本,以及*.xml为每个图像生成的相应文件
4. models:此文件夹将包含每个训练工作的子文件夹。每个子文件夹将包含训练流水线配置文件*.config,以及在训练和评估模型期间生成的所有文件。
5. pre-trained-models:此文件夹将包含下载的预训练模型,这些模型将用作我们训练工作的初始检查点。
6. addon 附加工具
在后面的程序编写过程中我们会对上面文件夹的描述有更为深刻的认识
3、准备数据集
3.1 安装标注工具
建议安装windows版本,从下面结果中安装windows最新版,无需安装,但放置的目录不能有中文。下载链接
http://tzutalin.github.io/labelImg/
首先需要将标注完的图片上传至
training_demo/
├─ images/
然后启动标注工具开始标注
3.2 分割数据集
完成对图像数据集的注释后,通常的惯例是仅将其中一部分用于训练,而其余部分用于评估。通常,比率为9:1,即90%的图像用于训练,其余的10%用于测试,但是您可以选择适合您需要的比率。分割的方式有两种,一种是手动分割,一种是程序分割,分割的比例是按照文件的数量而不是是实际标注的数量
import os
import re
from shutil import copyfile
import math
import random
def iterate_dir(source, dest, ratio, copy_xml):
source = source.replace('\\', '/')
dest = dest.replace('\\', '/')
train_dir = os.path.join(dest, 'train')
test_dir = os.path.join(dest, 'test')
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(test_dir):
os.makedirs(test_dir)
images = [f for f in os.listdir(source)
if re.search(r'([a-zA-Z0-9\s_\\.\-\(\):])+(.jpg|.jpeg|.png)$', f)]
num_images = len(images)
num_test_images = math.ceil(ratio*num_images)
for i in range(num_test_images):
idx = random.randint(0, len(images)-1)
filename = images[idx]
copyfile(os.path.join(source, filename),
os.path.join(test_dir, filename))
if copy_xml:
xml_filename = os.path.splitext(filename)[0]+'.xml'
copyfile(os.path.join(source, xml_filename),
os.path.join(test_dir,xml_filename))
images.remove(images[idx])
for filename in images:
copyfile(os.path.join(source, filename),
os.path.join(train_dir, filename))
if copy_xml:
xml_filename = os.path.splitext(filename)[0]+'.xml'
copyfile(os.path.join(source, xml_f