用Tensorflow Object Detection API 训练自己的数据集

一、准备数据集

Tensorflow Object Detection API 用 TFRecord 文件格式读取数据,需把 VOC 格式的数据集进行转换(我自己的数据集是VOC2007)

1、修改 tensorflow/models/object_detection/create_pascal_tf_record.py 文件第84行和162行。
这里写图片描述

这里写图片描述

2、修改tensorflow/models/object_detection/data/pascal_label_map.pbtxt 文件里的类别.
3、运行命令:

# From tensorflow/models
python object_detection/create_pascal_tf_record.py \
    --label_map_path=object_detection/data/pascal_label_map.pbtxt \
    --data_dir=VOCdevkit --year=VOC2007 --set=train \
    --output_path=pascal_train.record
python object_detection/create_pascal_tf_record.py \
    --label_map_path=object_detection/data/pascal_label_map.pbtxt \
    --data_dir=VOCdevkit --year=VOC2007 --set=val \
    --output_path=pascal_val.record

执行后会在object_detection文件夹下生成pascal_train.record和pascal_val.record两个文件。

二、下载预训练模型

下载地址:https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md
解压命令例子:

tar -xzvf ssd_mobilenet_v1_coco.tar.gz

三、修改配置文件

修改 object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_pets.config文件:

(1)num_classes:修改为自己的classes num

这里写图片描述

(2)将所有PATH_TO_BE_CONFIGURED的地方修改为自己之前设置的路径(共5处)

这里写图片描述

这里写图片描述

四、训练

进入object_detection目录,运行:

tensorflow/models$ python object_detection/train.py --train_dir='/home/anngic/tensorflow/train' --pipeline_config_path='/home/anngic/tensorflow/models/object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_coco.config'

五、tensorboad

输入命令:

tensorboard --logdir=/home/shz/TF-OD-Test/train

在浏览器中输入https://0.0.0.0:6006,就能看到训练曲线了。

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值