tensorflow object detection API训练自己数据集中遇到的坑

前段时间一直使用pytorch,没怎么关注tensorflow,最近突然想试一下tensorflow的object detection API。

首先非常感谢https://blog.csdn.net/zj1131190425/article/details/80711857博主的分享,写得非常详细。

首先根据博主的第一篇博文下载了models的全部代码,这边可以完全参考博主。但其中有个坑:

》》到目前为止,最新版的是models是1.13.0版本需要tensorflow 1.12.0以上版本,而gpu版本的tensorflow需要cuda10,但我的电脑安装的是cuda9.0,为了不重装cuda,我下载了1.12.0的models,需要1.9.0以上版本,所以重新安装了tensorflow-gpu1.9.0版本的。

环境匹配好后,便可以进行调试。

这里用到的是object_detection_tutorial.ipynb文件作为demo,该文件在jupyter notebook中运行,并得到最终的demo检测图,如博主所得。

然后开始训练自己的数据,这里参考博主https://blog.csdn.net/zj1131190425/article/details/80778888,首先制作自己的数据集,将xml转为csv,然后转为tfrecord文件。然后修改config文件,下载模型。然后遇到了两个问题:

》》1. 我在pycharm中无法import出来object detection 包,添加环境变量也不行,后来在pycharm中新建了一个虚拟环境,并将object detection的路径添加进去,重启后可以正常运行。(这里似乎打开了用pycharm来管理环境的新世界,回头有时间研究一下)

》》2. 由于我使用的版本与博主使用的models版本不同,在我的版本下object detection中没有train文件,但在legacy中找到了train 和trainer。不过在object detection下有model_main.py文件,目测应该是进化版的train,但是里面有一个超参数hparams_overrides一直没搞明白什么意思,所以没用该文件,而是用train 训练的。

至此,成功实现用该API训练自己的数据集。这是用我的数据训练2000step的结果:大概跑了一个多小时(1066显卡)

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值