【SCRDet++代码调试】损失很低,但检测效果并不好的问题
问题描述
- 复现 SCR Det++ 时,在 DOTA(crop) 数据集上训练模型7个小时,loss达到0.1493;但对模型进行 eval,精度仅为0.3%;test 也检测不到任何东西。
- 但是使用最原始的 Faster R-CNN 进行行人检测,仅仅训练2个小时,精度就能达到60%~80%。
原因分析
对上述现象进行分析,可能的原因有:
- 数据集本身的问题:由于遥感图像中目标太小(确实是太小太小了),所以对模型loss值的要求比自然场景下的要求高得多(也就是说,如果自然场景下的loss值为0.1时,检测效果会很好;那遥感图像就需要是0.01甚至更小);
【解决办法】:只能选择高性能GPU,增大epoch,慢慢训练。 - 网络结构太复杂的原因:虽然模型是基于 ResNet/VGG 进行的,且有预训练权重文件,但模型中还有很大一部分网络的权重并没有预训练,需要完全从头开始,所以需要的时候很长;
【解决办法】:只能选择高性能GPU,慢慢训练,得到整个网络的权重文件。 - 代码本身的问题:比如自己写的 eval.py 计算方法有误(但是训练完直接执行 test.py 确实检测不到东西)
- 暂时还没想出来的其他问题……
验证思路
1、数据集本身的问题?
首先,为了验证是不是遥感数据集本身的问题,下载了 Faster R-CNN 代码,并用3个不同的遥感数据集,验证在 Faster R-CNN 上的精度。
1)第1个数据集:DOTA
【结果】:不管如何选择,mAP都是0.0%。
- 使用全部数据集(crop后,14996张)训练4个epoch(坚持不到10个epoch就总是断网……),mAP=0.17%;
- 然后,考虑到 Faster R-CNN 或许对小目标检测效果不好,使用大目标(检测框面积>=100,000)训练10个epoch,mAP=0.38%;
- 最后,考虑到遥感图像中目标大多是俯视图,不同的目标形状相似(如篮球场、汽车等都是矩形),使用篮球场和飞机两个分类练10个epoch,mAP=0.01%。
2)第2个数据集:RSOD
然后,尝试了另一个数据集:RSOD,将其转为 VOC 格式。使用全部数据集(468张)训练4个epoch(不训练10个epoch的原因是,到了第4个,loss就nan了……),mAP=25.83%
3)第3个数据集:NWPU VHR-10
接着,对比 DOTA 和 RSOD 两个数据集的区别,觉得可能是水平/旋转检测框的问题。因此又选取了 NWPU VHR-10 数据集,使用650张图片训练10个epoch,mAP=69.73%
4)三个数据集分析
三个数据集都是“遥感数据集”,除了 DOTA,检测精度都还不错。所以检测精度不高不是因为遥感数据集中目标太小,更可能的是 DOTA 数据集本身的问题:
- 因为 DOTA 是旋转检测框?
- 因为 DOTA 被裁剪过?(但是那些没被裁剪的小目标,检测精度也不高啊……)
我觉得主要问题是旋转检测框。
DOTA(crop后) | RSOD | NWPU VHR-10 | |
---|---|---|---|
图片数量 | 14996 | 468 | 650 |
图片大小 | 800*800 | 1044*915 | 985*808 |
类别数 | 15 | 4 | 10 |
检测框 | 旋转 | 水平 | 水平 |
mAP | 0.17% | 25.83% | 69.73% |
FasterRCNN其实有能力精准的将所有的物体框住,只是框的重合度过高,在NMS中会有被过滤的风险
附:旋转检测框
在代码中,唯一对旋转检测框进行操作的地方是将旋转检测框(8个参数)转为水平检测框(4个参数)时的取点过程。为了验证取点前后检测框是否相差不大,用以下代码进行了验证:
import numpy as np
import cv2
import os
import xml.etree.ElementTree as ET
color_gt = (255, 0, 0)
color_md = (0, 255, 0)
thickness = 1
img_path = "E:\WorkSpace\PyCharmSpace\FPN\DOTA-DOAI\FPN_Tensorflow_Rotation\data\dataset_DOTA\DOTA1.0\\train-800\images"
file_path = "E:\WorkSpace\PyCharmSpace\FPN\DOTA-DOAI\FPN_Tensorflow_Rotation\data\dataset_DOTA\DOTA1.0\\train-800\labeltxt"
for f in os.listdir(file_path):
file = os.path.join(file_path, f)
tree = ET.parse(file)
root = tree.getroot()
size = tree.find('size')
width = float(size.find('width').text)
height = float(size.find('height').text)
objs = root.findall('object')
file_name = f.split('.')
img_source = os.path.join(img_path, file_name[0] + '.png')
img = cv2.imread(img_source, 0)
for ix, obj in enumerate(objs):
bbox = obj.find('bndbox')
x1 = float(bbox.find('x0').text)
y1 = float(bbox.find('y0').text)
x2 = float(bbox.find('x1').text)
y2 = float(bbox.find('y1').text)
x3 = float(bbox.find('x2').text)
y3 = float(bbox.find('y2').text)
x4 = float(bbox.find('x3').text)
y4 = float(bbox.find('y3').text)
x_list = [x1, x2, x3, x4]
y_list = [y1, y2, y3, y4]
y_max = min(max(y_list), height)
y_min = max(min(y_list), 0)
x_max = min(max(x_list), width)
x_min = max(min(x_list), 0)
newBox = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
# 如果坐标点需要排序则使用下面函数
point = np.array(newBox).astype(int)
cv2.line(img, tuple(point[0]), tuple(point[1]), color_gt, thickness)
cv2.line(img, tuple(point[1]), tuple(point[2]), color_gt, thickness)
cv2.line(img, tuple(point[2]), tuple(point[3]), color_gt, thickness)
cv2.line(img, tuple(point[3]), tuple(point[0]), color_gt, thickness)
# -------------------------
ptLeftTop = (int(x_min), int(y_min))
ptRightBottom = (int(x_max), int(y_max))
cv2.rectangle(img, ptLeftTop, ptRightBottom, color_md)
image = np.expand_dims(img, axis=2)
image = np.concatenate((image, image, image), axis=-1)
cv2.imshow('image', img)
k = cv2.waitKey(0)
if k == 27: # wait for ESC key to exit
cv2.destroyAllWindows()
elif k == ord('s'): # wait for 's' key to save and exit
cv2.imwrite(file_name[0] + '.png', img)
cv2.destroyAllWindows()
下面是几张效果图:
在plane、small-vehicle、large-vehicle、ground-track-field
四类中,
- 只有
plane
这一类特征明显且几乎无遮挡(左下图); small-vehicle
目标太小且存在遮挡(左上图);large-vehicle
存在的遮挡问题非常明显(右上图),遮挡问题在ship、harbor、bridge
等类别上也非常明显;- 对于
ground-track-field
这一类大目标,经过裁剪后已经没有全局的特征了(右下图)。
2、网络结构的问题?
然后,为了验证是不是网络结构的问题,下载了 VEDAI 数据集,并将其转换为VOC格式,以供 SCR Det++ 网络训练用。
用 VEDAI 数据集训练,发现 faster_rcnn的两个loss一直为0,虽然最终经过训练loss能降到0.148,但是仍然检测不到任何东西。
两个公共数据集经过训练都检测不到任何东西,所以大概率是公开的 SCR Det++ 代码有问题(当然,公开代码只是部分代码,但是可能还有其他问题存在),需要自己实现某些部分。