训练SSD网络的时候,出现报错:img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) IndexError: too many indices for array
IndexError: Too many indices for array:Array is 1-dimensional,but 2 were indexed
网上的解决方法有很多:SSD-pytorch 训练过程全记录_如何使ssd-pytorch记录训练的结果-CSDN博客训练ssd300时网络报错_导入ssd300出错_口在天上,数在心中的博客-CSDN博客【精选】SSD-Pytorch训练自己的VOC数据集&遇到的问题及解决办法_为什么制作好voc数据集运行代码,没有生成文件-CSDN博客【精选】SSD-Pytorch训练自己的VOC数据集&遇到的问题及解决办法_为什么制作好voc数据集运行代码,没有生成文件-CSDN博客最后发现是xml文件中有一个空object导致的,可以使用代码检查自己的xml文件或者一个一个打开看:
import argparse
import sys
import cv2
import os
import os.path as osp
import numpy as np
if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET
else:
import xml.etree.ElementTree as ET
parser = argparse.ArgumentParser(
description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()
parser.add_argument('--root', default='/home/sd4t/why/workplace/ssd.pytorch-master/data/VOCdevkit/VOC2012' , help='Dataset root directory path')
args = parser.parse_args()
CLASSES = [('person')]
annopath = osp.join('%s', 'Annotations', '%s.{}'.format("xml"))
imgpath = osp.join('%s', 'JPEGImages', '%s.{}'.format("png"))
def vocChecker(image_id, width, height, keep_difficult = False):
target = ET.parse(annopath % image_id).getroot()
res = []
for obj in target.iter('object'):
difficult = int(obj.find('difficult').text) == 1
if not keep_difficult and difficult:
continue
name = obj.find('name').text.lower().strip()
bbox = obj.find('bndbox')
pts = ['xmin', 'ymin', 'xmax', 'ymax']
bndbox = []
for i, pt in enumerate(pts):
cur_pt = int(bbox.find(pt).text) - 1
# scale height or width
cur_pt = float(cur_pt) / width if i % 2 == 0 else float(cur_pt) / height
bndbox.append(cur_pt)
label_idx = dict(zip(CLASSES, range(len(CLASSES))))[name]
bndbox.append(label_idx)
res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind]
# img_id = target.find('filename').text[:-4]
try :
np.array(res)[:,4]
np.array(res)[:,:4]
except IndexError:
print(image_id+" had error index")
return res # [[xmin, ymin, xmax, ymax, label_ind], ... ]
if __name__ == '__main__' :
i = 0
for name in sorted(os.listdir(osp.join(args.root,'Annotations'))):
# as we have only one annotations file per image
i += 1
img = cv2.imread(imgpath % (args.root,name.split('.')[0]))
height, width, channels = img.shape
res = vocChecker((args.root, name.split('.')[0]), height, width)
print("Total of annotations : {}".format(i))