文章目录
1 模型代码来源
我使用的模型是B站博主@霹雳吧啦Wz的代码,他写模型的读取方法时考虑的是pascal voc这个数据集,因此如果我们如果使用其他数据集,难免会遇见一些问题,这里记录了我是如何解决的。
ps: 如果使用ultralytics版本的yolov8,他们的代码会有忽略错误标框的能力。
2 对Wider Face数据集的训练处理
2.1 转换为voc格式
参考这篇的做法,一套流程下来,voc格式,coco格式和yolo格式都有了。
2.2 错误标框导致训练不能进行的问题
我一开始时用他仓库目录下的retinanet这个模型,代码有一定的容错能力,在碰见宽度/长度为0的框时会raise error,具体的代码在retinaNet/network_files/retinanet.py
的482-493
行
if targets is not None:
for target_idx, target in enumerate(targets):
boxes = target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
continue # <----这个continue是我加的,不然训练会停
# print the first degenerate box
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}."
.format(degen_bb, target_idx))
所以我直接在循环条件里加了continue看看能不能强行训练,后面发现不行(到某个step损失为Inf),因为实际有问题的图片大概有30张,我的batch为8,若batch大于30或许能强行训练,总之需要挑出有问题的图片,并且这一步加的continue不要删。
2.2.1 挑出误标图片
直接上代码吧,最后会有一个err_xml.txt
文件显示有问题的图片。按照它,把voc格式的wider face的train.txt
中的对应路径给删除。