序言
最近好多项目都用到了yolov5,为了适配项目需求,有时候会对v5的代码部分做一些改动,写这篇文章的目的是为了记录自己改动的部分(本人很健忘),以便下次用的时候翻来看看。
一、数据增强部分增加了垂直旋转的增强
v5里面提供了很多数据增强方式,针对大众的数据集效果还是很适用的,如果觉得增强效果不理想,也可以自己在代码中增加,我这里就增加了一个垂直的旋转90°的,为了让数据集中包含了垂直的目标数据,比如之前做的卡片四个角检测,如果目标全是水平的卡片,那么检测时卡片位置变为垂直的时候,效果可能就会有折扣,但是这类增强并不是所有场景都适用,下图为垂直旋转前和垂直旋转后,针对于这类数据,该增强的效果还是很不错的,可以有效提高模型泛化能力。
代码改动部分:utils/datasets.py 里数据集加载的LoadImagesAndLabels.__getitem__部分(大概在530行的位置):
二、损失修改了置信度的赋值
因为在刚开始使用v5的时候,发现训练得到的置信度都偏低,然后检查了它的置信度部分的损失,直接给它赋1,本来只是打算尝试一下的,因为之前的v3就是这么干的,没想到效果还很不错?这部分代码在loss.py里的compute_loss函数里:
看下修改前和修改后的置信度对比:
可以看到修改后拟合出来的目标框置信度基本上都是接近1的置信度,因为检测任务比较简单,目标较明显,没修改的置信度要偏低一些,如果在检测时对置信度有要求的话不妨可以尝试一下。目前一直在这么用,个人感觉效果要比未修改前要好很多。
三、所有类别参与NMS
总所周知,多目标检测的NMS都是同类别的做,但是因为我平时用的场景偏向于检测文本之类的,不太可能会有不同的类叠加在一起。但是有时候又会出现这样的问题:
是不是有一些置信度低的框也被保留的,在OCR文本检测角度来说,一个框肯定只包含了一个类,所以其他置信度低的框与之重合的话都应该被过滤,这时候就可以把所有的类别都参与进来做NMS,修改部分在:utlis/general.py的non_max_suppression中
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
"""Performs Non-Maximum Suppression (NMS) on inference results
Returns:
detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
"""
.
. 这里省略前面部分代码
.
# If none remain process next image
n = x.shape[0] # number of boxes
if not n:
continue
# ----------------------------------------
# 全部类别参与NMS,增加了这部分和下面那部分
class_name = x[:,5:6].clone()
x[:,5:6] = torch.zeros(x[:,5:6].size())
# Sort by confidence
# x = x[x[:, 4].argsort(descending=True)]
# ---------------------------------------
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
try: # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
if redundant:
i = i[iou.sum(1) > 1] # require redundancy
except: # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
print(x, i, x.shape, i.shape)
pass
# -----------别忘了把类别赋值回去-----------
x[:,5:6] = class_name
# -------------------------------------------
output[xi] = x[i]
if (time.time() - t) > time_limit:
break # time limit exceeded
return output
全部NMS过滤后,干净了许多,当然也可以通过置信度阈值筛选把低置信度的检测框过滤,但是当训练不好的时候,置信度阈值调的太高,漏检的概率就会大大增加。