参考官方给出的教程train_custom_data
- 配置环境, install yolox.
- 准备voc dataset, 调试train代码.
- 把自己的 dataset, 转换为voc格式, 调试train代码.
- 选择不同的模型结构, 训练最终可使用的model.
- 模型转换压缩, inference部署应用。
本文重点介绍dataset的处理
- install yolox
#服务器--pytorch环境
git clone git@github.com:Megvii-BaseDetection/YOLOX.git
cd YOLOX
pip3 install -v -e . # or python3 setup.py develop
conda create -n yolox python=3.7 #用python 3.8也可以
conda activate yolox
`#如果你切换了国内的源可以把后面的-c pytorch去掉。
conda install pytorch=1.7 torchvision cudatoolkit=10.2 -c pytorch``
git clone git@github.com:Megvii-BaseDetection/YOLOX.git
cd YOLOX
pip install -r requirements.txt
python setup.py develop
安装完成, 下载pretrained model yolox_s测试
python tools/demo.py image -f exps/default/yolox_s.py -c yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device gpu
-
voc dataset
download voc dataset
作为初步验证, 仅使用了VOC2007的部分数据, 由于后面想做单类别的检测, 仅使用了voc中的car作为验证。
8张测试数据 提取码:wxyh
使用Python把原始数据中的car类别挑选出来
import xml.etree.ElementTree as ET
import os
def newImageSets(oldSets, newSets):
#保存含有car的文件名
savelist = []
with open(oldSets, 'r') as f:
for line in f.readlines():
ids = int(line)
path_i = 'Annotations/%06d.xml'%ids
if os.path.exists(path_i):
print(path_i)
savelist.append(line)
with open(newSets, 'a') as f1:
for id in savelist:
f1.write(id)
return
def selectCarAnn(srcAnnPath, dstAnnPath):
srcPath = os.path.join(srcAnnPath, "%06d.xml")
dstPath = os.path.join(dstAnnPath, "%06d.xml")
count = 0
#遍历所有标签文件
for id in range(1,9964):
_path = srcPath % id
rootTree = ET.parse(_path)
target = rootTree.getroot()
#判断此标签文件中, 是否有car
carFlag = False
for obj in target.iter("object"):
name = obj.find("name").text.strip()
if name == 'car':
carFlag = True
#print("name: ",_path,"--", name)
#如果有car, remove所有非car的标注box
if carFlag:
count += 1
#print(count)
#保存需要remove的非car物体
rm_list = []
for obj in target.iter("object"):
name = obj.find("name").text.strip()
if name != 'car':
rm_list.append(obj)
for o in rm_list:
target.remove(o)
rootTree.write(dstPath%id)
print(count)
return
def main():
selectCarAnn("Annotations", "Annotations_new")
newImageSets("ImageSets/Main/test.txt", "ImageSets/Main/test_new.txt")
return
if __name__ == '__main__':
main()
-
Hand dataset
download Hand dataset
此数据集是用matlab打包的, 这里使用python解析.mat数据文件如果想查看原始数据集的标注情况, 使用以下Python代码
import scipy.io as scio
import cv2
import random
import colorsys
import os
def loadbox(data):
out = []
for box in data['boxes'][0]:
p0 = box[0][0][0]
p1 = box[0][0][1]
p2 = box[0][0][2]
p3 = box[0][0][3]
res = []
res.append(p0[0])
res.append(p1[0])
res.append(p2[0])
res.append(p3[