之前用RFBNet进行目标检测,采用的数据集是VOC2007和VOC2012。最近用在自己的数据集进行训练,由于我的数据集格式跟VOC格式不一样,根据网上的经验,我就开始将自己的数据集制作成VOC格式的方便训练。但自己的数据集和标准的数据集质量真心不能比,有很多问题,花费了好多时间在数据处理上。。。
我遇到的问题主要是数据集的问题,而RFBNet是基于SSD的,所以SSD的如果出现这个问题大概率是一样的,当然其他目标检测网络也可以参考。
训练自己数据集
loss=nan问题
在制作完自己数据集后,训练RFBNet的时候,出现loss_l=nan的情况。
同时还出现RuntimeWarning:
Code/RFBNet-pytorch0.4.0/utils/box_utils.py:84: RuntimeWarning: invalid value encountered in true_divide
return area_i / (area_a[:, np.newaxis] + area_b - area_i)
说明在代码的utils/box_utils.py的match_iou方法返回值中除法遇到了除数为0的情况。
解决方法
在网上查了一番,发现大家也都遇到类似的情况。
出现nan情况一般有以下集中可能:
- 数据问题,例如目标检测中可能出现bbox位置为(0,0),可能xmin>xmax
- 网络结构问题
- 训练问题等
在逼乎上某一大佬根据他自己经验给出一回答,将batch_size 调成1,shuffle调成False,查看到底哪些数据存在问题。受他启发,我也进行检查,将train_RFB.py中batch_size 调成1,shuffle调成False,ngpu、num_workers调成1,发现网络只是在某一iteration的时候出现loss=nan,所以基本确定是数据集的问题。
确定了是数据集的问题,还得确定是数据集的哪里出现了问题。这个过程实在是有点大海捞针的感觉,我试过将xml中bndbox值转成int,因为VOC中坐标值都是Int,object的name中特殊的字符映射到正常的字符等等。。。但这些都没能解决问题。
就在今天早上,读到一篇[微信推送]帮了我大忙。里面讲到bndbox的iou计算问题,其中有xmax-xmin,ymax-ymin。我就猜想会不会是这个问题。于是我先在VOC2007中检测是否存xmin>xmax的情况,结果发现没有!
然后我又在我的数据集中查找,结果发现,我去,1万条数据中有2000多条存在xmin>xmax!!!
检测xmin>xmax程序
import os
import xml.etree.ElementTree as ET
xml_dir = './Annotations'
def compare_min_max(xml_dir):
xmls = os.listdir(xml_dir)
xmls.sort()
flag = 0
count = 0
for xml in xmls:
xml_path = os.path.join(xml_dir, xml)
tree = ET.parse(xml_path)
root = tree.getroot()
for elem in root.findall('object'):
xmin = elem.find('bndbox').find('xmin').text
ymin = elem.find('bndbox').find('ymin').text
xmax = elem.find('bndbox').find('xmax').text
ymax = elem.find('bndbox').find('ymax').text
if int(ymin) > int(ymax) or int(xmin) > int(xmax):
print('min > max in file:',xml_path)
flag = 1
if flag == 1:
count += 1
flag = 0
print('{} files that min > max'.format(count))
print('finish comparision...')
if __name__ == '__main__':
compare_min_max(xml_dir)
在发现自己数据集存在这个问题后,我就重新制作了一遍数据集,这次数据集没有出现xmin(ymin)>xmax(ymax)的情况了。
制作数据集程序
##将数据集中img和rext中的信息转成VOC annotation的 xml格式
import os
import cv2
import io
import pandas as pd
img_path = r'./JPEGImages'
rect_path = r'./rect'
xml_path = r'./Annotations'
#read images ,get image's w,h c and name
def read_image(filename):
img = cv2.imread(filename)
h, w, c = img.shape
basename = os.path.basename(filename)
return h, w, c, basename
#读取rect的txt中第一个空行之前的内容
def file_reader(filename):
with open(filename) as f:
for line in f:
if line and line != '\n':
yield line
else:
break
#获取目标字符类别以及bbox
def get_object_bbox(filename):
bbox = []
data = io.StringIO(''.join(file_reader(filename)))
dataframe = pd.read_csv(data, skiprows=2, header=None)
for row in dataframe.iterrows():
if isinstance(row[1][2], str):
row[1][2] = row[1][2].strip() #delete space in string
r = [row[1][2],row[1][5], row[1][6], row[1][7], row[1][8]]
bbox.append(r)
return bbox #shape(n,5) n number of bndboxes, each bndbox has the form[object, xmin, ymin,xmax,ymax]
def write_xml(h, w, c, bbox, basename):
front, extend = os.path.splitext(basename)
front += '.xml'
full_path = os.path.join(xml_path, front)
with open(full_path,'w') as f:
f.write('<annotation>\n')
f.write(' <folder>OHWME</folder>\n')
f.write(' <filename>' + str(basename) + '</filename>\n')
f.write(' <source>\n')
f.write(' <database>MyDataBase</database>\n')
f.write(' <annotation>PASCAL VOC2007</annotation>\n')
f.write(' <image>f</image>\n')
f.write(' </source>\n')
f.write(' <size>\n')
f.write(' <width>' + str(w) + '</width>\n')
f.write(' <height>' + str(h) + '</height>\n')
f.write(' <depth>' + str(c) + '</depth>\n')
f.write(' </size>\n')
f.write(' <segmented>0</segmented>\n')
for b in bbox:
object = b[0]
if object == '/':
object = r'\backslash'
if object == '.':
object = r'\dot'
xmin = b[1]
ymin = b[2]
xmax = b[3]
ymax = b[4]
f.write(' <object>\n')
f.write(' <name>' + str(object) + '</name>\n')
f.write(' <pose>Unspecified</pose>\n')
f.write(' <truncated>0</truncated>\n')
f.write(' <difficult>0</difficult>\n')
f.write(' <bndbox>\n')
##avoid xmin,ymin > xmax,ymax
if int(xmin) > int(xmax):
xmax, xmin = xmin, xmax
if int(ymin) > int(ymax):
ymax, ymin = ymin, ymax
#avoid (0,0) which would probaly result in nan
if int(xmin) < 1:
f.write(' <xmin>' + str(int(xmin + 1)) + '</xmin>\n')
else:
f.write(' <xmin>' + str(int(xmin)) + '</xmin>\n')
if int(ymin) < 1:
f.write(' <ymin>' + str(int(ymin + 1)) + '</ymin>\n')
else:
f.write(' <ymin>' + str(int(ymin)) + '</ymin>\n')
if int(xmax < 1):
f.write(' <xmax>' + str(int(xmax + 1)) + '</xmax>\n')
else:
f.write(' <xmax>' + str(int(xmax)) + '</xmax>\n')
if int(ymax < 1):
f.write(' <ymax>' + str(int(ymax + 1)) + '</ymax>\n')
else:
f.write(' <ymax>' + str(int(ymax)) + '</ymax>\n')
f.write(' </bndbox>\n')
f.write(' </object>\n')
f.write('</annotation>')
if __name__ == '__main__':
img_names = os.listdir(img_path)
rect_names = os.listdir(rect_path)
img_names.sort()
rect_names.sort()
for img_name, rect_name in zip(img_names, rect_names):
full_image_path = os.path.join(img_path, img_name)
full_rect_path = os.path.join(rect_path, rect_name)
h, w, c, basename = read_image(full_image_path)
bbox = get_object_bbox(full_rect_path)
print('writing {}\\{}.xml'.format(xml_path,os.path.splitext(basename)[0]))
write_xml(h, w, c, bbox, basename)
训练结果
重新制作数据集后,训练过程如下:
从图中可以看出loss_l L已经正常并且开始下降,说明数据集格式正确了,结果由于网络还在训练,所以还没有test结果,但至少说明开始正确训练了,至于mAP能有多少还得调参哈哈哈(请叫我调参侠)。
总结
目标检测中若遇到loss为nan的情况,
首先,检查数据集格式问题。如bbox的xmin,ymin是否大于xmax,ymax,或者坐标是否存在为0的情况。
其次,检查网络结构是否存在问题。
还有,训练的方法是否有问题。
希望我的经验能帮助遇到类似问题的朋友,少掉点头发,少走点弯路。
参考
https://www.zhihu.com/question/49346370
https://mp.weixin.qq.com/s/TMRDhDrf5rRRFIdGGL8Uhg