上一篇文章配置好了适用于tensorflow2.x版本的Object_Detection_API,这篇文章将使用Lara数据集训练交通信号灯识别器。
1 下载数据集
下载Lara交通信号灯数据集,将其解压到traffic_lights文件夹里,包含图片文件夹和.txt的标签文件:
2 数据集处理
1)运行./traffic_lights文件夹里的preprocess.py,得到存在图片和标签文件的文件夹:
其中labels文件夹中的标签文件为单个图片对应的标签:
2)运行./traffic_lights文件夹里的txt2xml.py,将txt格式的标签转为xml格式,存放在./traffic_lights/Annotations文件夹中:
3)划分训练集和测试集:在./traffic_lights/Images文件夹里新建train和test文件夹,将7335张图片剪切到train文件夹,将剩下的剪切到test文件夹,则Image文件夹包含以下两个文件夹:
4)运行./traffic_lights文件夹里的xml_to_csv.py,将标签转为csv格式,运行完后,Images文件夹下生成以下两个文件:
5)运行./object_detection文件夹里的generate_tfrecord.py,将标签转为tfrecrd格式,运行完后,./object_detection/traffic_lights文件夹下生成以下两个文件:
3 生成标注映射图并配置训练
1)标注映射图(按上一篇博客下载的API中,已经操作好此步,可以省略)
标注映射图通过定义一个目标类别名到目标类别ID的映射来告诉训练器每一个物体是什么。在 .\object_detection\TL_training文件夹下新建一个文件并保存为labelmap.pbtxt。将以下内容复制到文本中。
item {
id: 1
name: 'Traffic_Light_go'
}
item {
id: 2
name: 'Traffic_Light_stop'
}
item {
id: 3
name: 'Traffic_Light_ambiguous'
}
item {
id: 4
name: 'Traffic_Light_warning'
}
2)配置训练(按上一篇博客下载的API中,已经操作好此步,可以省略)
最后,必须配置物体识别训练管道。它定义了哪些模型和参数将被用于训练。这是开始训练前的最后一步!
从.\research\object_detection\samples\configs文件夹中复制faster_rcnn_inception_v2_coco.config文件到\object_detection\TL_training路径下。然后用文本编辑器打开文件。需要对其做几处修改,主要修改类别和样本的数目,并添加文件路径到训练数据。
在faster_rcnn_inception_v2_coco.config文件中做如下几处修改。Note:路径必须使用单个正斜杠(/),而不是反斜杠(\)。同样,路径必须使用双引号(")而非单引号(')。
- Line 10. 修改num_classes为你想要识别的物体的类别数目。此交通信号灯识别器中num_classes是4。
- Line 107. 修改 fine_tune_checkpoint为:
- fine_tune_checkpoint : "D:/pycharm_community/tensorflow_object_detection_API/models_master/research/object_detection/faster_rcnn_inception_v2_coco_2018_01_28/model.ckpt"
- Lines 122 and 124. 在train_input_reader模块中,修改input_path和label_map_path为:
- input_path : "D:/pycharm_community/tensorflow_object_detection_API/models_master/research/object_detection/traffic_lights/train.record"
- label_map_path: "D:/pycharm_community/tensorflow_object_detection_API/models_master/research/object_detection/TL_training/labelmap.pbtxt"
- Lines 136和138. 在eval_input_reader模块,修改input_path和label_map_path为:
- input_path : "D:/pycharm_community/tensorflow_object_detection_API/models_master/research/object_detection/traffic_lights/test.record"
- label_map_path: "D:/pycharm_community/tensorflow_object_detection_API/models_master/research/object_detection/TL_training/labelmap.pbtxt"
修改完毕后保存文件。
4 训练模型
运行./object_detection文件中的train_TL.py开始训练模型,训练的过程如下图所示:
可以通过TensorBoard观察训练过程。打开电脑终端,修改路径至D:\pycharm_community\tensorflow_object_detection_API\models_master\research\object_detection,并运行以下命令:
tensorboard --logdir=TL_training
这将在本地创建一个端口为6006的网页,可在浏览器上浏览。TensorBoard网页提供展示训练如何进行的信息和图表。
5 保存模型
在./object_detection文件夹下新建TL_model文件夹,将TL_training文件夹中的以下文件剪切到TL_model文件夹:
最后一步是生成冻结结论图(.pb文件)。在\object_detection路径下新建TL_inference_graph文件夹存在.pb文件。然后运行./object_detection下的export_TL_inference_graph.py,其中'trained_checkpoint_prefix'中的"model.ckpt-XXXX"中的"XXXX"需要被替换成TL_model文件夹下最高编号的.ckpt文件。运行成功后在 TL_inference_graph文件夹中生成以下文件:
6 测试模型
至此,交通信号灯识别器准备就绪,运行./object_detection中的TL_image_test.py对其进行测试,测试结果如下图所示(只训练了1000多步,效果不是太好):