前段时间一直使用pytorch,没怎么关注tensorflow,最近突然想试一下tensorflow的object detection API。
首先非常感谢https://blog.csdn.net/zj1131190425/article/details/80711857博主的分享,写得非常详细。
首先根据博主的第一篇博文下载了models的全部代码,这边可以完全参考博主。但其中有个坑:
》》到目前为止,最新版的是models是1.13.0版本需要tensorflow 1.12.0以上版本,而gpu版本的tensorflow需要cuda10,但我的电脑安装的是cuda9.0,为了不重装cuda,我下载了1.12.0的models,需要1.9.0以上版本,所以重新安装了tensorflow-gpu1.9.0版本的。
环境匹配好后,便可以进行调试。
这里用到的是object_detection_tutorial.ipynb文件作为demo,该文件在jupyter notebook中运行,并得到最终的demo检测图,如博主所得。
然后开始训练自己的数据,这里参考博主https://blog.csdn.net/zj1131190425/article/details/80778888,首先制作自己的数据集,将xml转为csv,然后转为tfrecord文件。然后修改config文件,下载模型。然后遇到了两个问题:
》》1. 我在pycharm中无法import出来object detection 包,添加环境变量也不行,后来在pycharm中新建了一个虚拟环境,并将object detection的路径添加进去,重启后可以正常运行。(这里似乎打开了用pycharm来管理环境的新世界,回头有时间研究一下)
》》2. 由于我使用的版本与博主使用的models版本不同,在我的版本下object detection中没有train文件,但在legacy中找到了train 和trainer。不过在object detection下有model_main.py文件,目测应该是进化版的train,但是里面有一个超参数hparams_overrides一直没搞明白什么意思,所以没用该文件,而是用train 训练的。
至此,成功实现用该API训练自己的数据集。这是用我的数据训练2000step的结果:大概跑了一个多小时(1066显卡)