一、准备数据集
Tensorflow Object Detection API 用 TFRecord 文件格式读取数据,需把 VOC 格式的数据集进行转换(我自己的数据集是VOC2007)
1、修改 tensorflow/models/object_detection/create_pascal_tf_record.py 文件第84行和162行。
2、修改tensorflow/models/object_detection/data/pascal_label_map.pbtxt 文件里的类别.
3、运行命令:
# From tensorflow/models
python object_detection/create_pascal_tf_record.py \
--label_map_path=object_detection/data/pascal_label_map.pbtxt \
--data_dir=VOCdevkit --year=VOC2007 --set=train \
--output_path=pascal_train.record
python object_detection/create_pascal_tf_record.py \
--label_map_path=object_detection/data/pascal_label_map.pbtxt \
--data_dir=VOCdevkit --year=VOC2007 --set=val \
--output_path=pascal_val.record
执行后会在object_detection文件夹下生成pascal_train.record和pascal_val.record两个文件。
二、下载预训练模型
下载地址:https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md
解压命令例子:
tar -xzvf ssd_mobilenet_v1_coco.tar.gz
三、修改配置文件
修改 object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_pets.config文件:
(1)num_classes:修改为自己的classes num
(2)将所有PATH_TO_BE_CONFIGURED的地方修改为自己之前设置的路径(共5处)
四、训练
进入object_detection目录,运行:
tensorflow/models$ python object_detection/train.py --train_dir='/home/anngic/tensorflow/train' --pipeline_config_path='/home/anngic/tensorflow/models/object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_coco.config'
五、tensorboad
输入命令:
tensorboard --logdir=/home/shz/TF-OD-Test/train
在浏览器中输入https://0.0.0.0:6006,就能看到训练曲线了。