SSD_pytorch: assorted pitfalls
Disclaimer: this post is purely a personal study log.
This post is based on the following code:
Code: https://github.com/amdegroot/ssd.pytorch
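Because of PyTorch version differences, the code has a lot of bugs. I worked through the fixes that various bloggers had written up, one by one, and finally got it running. Many thanks to all of them!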
Code walkthroughs (there are others besides these):
https://zhuanlan.zhihu.com/p/195372992
https://www.cnblogs.com/cmai/p/10080005.html
Posts I used to track down the bugs:
https://www.jianshu.com/p/fb4338d5b800
https://zhuanlan.zhihu.com/p/101600509
The bug that took me the longest to track down was "img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) IndexError: too many indices for array #430".
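The error message already hints at the cause: if an image yields no boxes (no objects at all, or only objects marked difficult while difficult objects are skipped), the target list is empty, NumPy turns it into a 1-D array, and the 2-D slices target[:, :4] and target[:, 4] fail. A minimal reproduction, assuming only NumPy; split_target is a hypothetical helper that mimics the slicing in the line above, not a function from ssd.pytorch:

```python
import numpy as np


def split_target(target):
    """Hypothetical helper: mimic target[:, :4], target[:, 4] from the training code."""
    arr = np.array(target)
    return arr[:, :4], arr[:, 4]


# One box [xmin, ymin, xmax, ymax, label]: a 2-D array of shape (1, 5), slicing works.
boxes, labels = split_target([[0.1, 0.2, 0.5, 0.6, 14]])
print(boxes.shape, labels)  # (1, 4) [14.]

# No boxes: np.array([]) has shape (0,), i.e. it is 1-D, so 2-D slicing fails.
try:
    split_target([])
except IndexError as e:
    # the message starts with "too many indices for array"
    print("IndexError:", e)
```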
I eventually found the answer on GitHub.
The root cause is that some samples in the dataset contain no objects. I added a script that finds these samples so they can be removed:
import argparse
import os
import os.path as osp
import sys

import cv2
import numpy as np

if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

parser = argparse.ArgumentParser(
    description='Single Shot MultiBox Detector Training With Pytorch')
parser.add_argument('--root', required=True,
                    help='Dataset root directory path (the folder containing Annotations/ and JPEGImages/)')
args = parser.parse_args()
# The original post hardcoded its own dataset path into args.root at this point
# instead of passing --root; passing --root on the command line does the same thing.

CLASSES = (  # always index 0
    'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor')
CLASS_TO_IND = dict(zip(CLASSES, range(len(CLASSES))))

annopath = osp.join('%s', 'Annotations', '%s.xml')
imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')


def vocChecker(image_id, width, height, keep_difficult=False):
    """Parse one VOC annotation the same way the training code does."""
    target = ET.parse(annopath % image_id).getroot()
    res = []
    for obj in target.iter('object'):
        difficult = int(obj.find('difficult').text) == 1
        if not keep_difficult and difficult:
            continue
        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')
        pts = ['xmin', 'ymin', 'xmax', 'ymax']
        bndbox = []
        for i, pt in enumerate(pts):
            cur_pt = int(bbox.find(pt).text) - 1
            # scale x coordinates by the width, y coordinates by the height
            cur_pt = float(cur_pt) / width if i % 2 == 0 else float(cur_pt) / height
            bndbox.append(cur_pt)
        bndbox.append(CLASS_TO_IND[name])
        res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
    return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]


if __name__ == '__main__':
    i = 0
    bad_ids = []
    for name in sorted(os.listdir(osp.join(args.root, 'Annotations'))):
        # there is exactly one annotation file per image
        i += 1
        image_id = name.split('.')[0]
        img = cv2.imread(imgpath % (args.root, image_id))
        if img is None:
            print("missing image : {}".format(imgpath % (args.root, image_id)))
            continue
        height, width, channels = img.shape
        # vocChecker expects (width, height); the original call passed them swapped
        res = vocChecker((args.root, image_id), width, height)
        try:
            # the same 2-D slicing the training transform applies to the target
            np.array(res)[:, 4]
            np.array(res)[:, :4]
        except IndexError:
            # res is empty -> np.array(res) is 1-D -> "too many indices for array"
            print("INDEX ERROR HERE : {}".format(annopath % (args.root, image_id)))
            bad_ids.append(image_id)
    print("Total of annotations : {}".format(i))
    print("Images without usable objects ({}): {}".format(len(bad_ids), bad_ids))
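The checker above only reports problem images; it does not change the dataset. One way to actually exclude them is to write a filtered image-ID list. This is a minimal sketch under stated assumptions: the standard VOC layout (Annotations/*.xml, ImageSets/Main/<split>.txt), and a dataset class that reads its image IDs from ImageSets/Main/<split>.txt. The function names has_usable_object and filter_image_ids and the output split name trainval_nonempty are hypothetical, and the training code would still need to be pointed at the new list.

```python
import os.path as osp
import xml.etree.ElementTree as ET


def has_usable_object(xml_path, keep_difficult=False):
    """Hypothetical helper: True if at least one object would survive parsing."""
    root = ET.parse(xml_path).getroot()
    for obj in root.iter('object'):
        difficult = obj.find('difficult')
        is_difficult = difficult is not None and int(difficult.text) == 1
        if keep_difficult or not is_difficult:
            return True
    return False


def filter_image_ids(voc_root, split='trainval', out_split='trainval_nonempty',
                     keep_difficult=False):
    """Hypothetical helper: copy ImageSets/Main/<split>.txt to <out_split>.txt,
    keeping only IDs with a usable object. Returns the dropped IDs."""
    main_dir = osp.join(voc_root, 'ImageSets', 'Main')
    with open(osp.join(main_dir, split + '.txt')) as f:
        ids = [line.strip() for line in f if line.strip()]
    kept, dropped = [], []
    for image_id in ids:
        xml_path = osp.join(voc_root, 'Annotations', image_id + '.xml')
        if has_usable_object(xml_path, keep_difficult):
            kept.append(image_id)
        else:
            dropped.append(image_id)
    with open(osp.join(main_dir, out_split + '.txt'), 'w') as f:
        f.write('\n'.join(kept) + '\n')
    return dropped
```

Usage, with a placeholder path: dropped = filter_image_ids('/path/to/VOCdevkit/VOC2007'). Keeping the same keep_difficult setting as training matters: an image whose only objects are difficult is empty when difficult objects are skipped, but fine when they are kept.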