上一篇博客https://blog.csdn.net/goodxin_ie/article/details/84315458我们详细介绍了pascal&&coco数据集,本篇我们将介绍pytorch如何加载
一、目标
pascal数据集的数据源是jpg图片,便签是xml文件,而pytorch运算使用的数据是Tensor。因此我们的目标是将jpg和xml文件转化为可供程序运算使用的Tensor或者numpy类型(Tesnor和numpy可以相互转化)。
回忆一下目标检测算法需要的标签信息,有类别和bbox框。在pascal数据集中,每张图片中的对象由xml中的objec标定,每个对象存在类别名name,位置框('ymin', 'xmin', 'ymax', 'xmax'),是否为困难样本的标记difficult。
二、解析xml文件
调用ElementTree元素树可以很方便的解析出xml文件的各种信息。我们主要使用其中的find方法查找对应属性的信息
ET.findall('object') #查找对象
ET.findall('bndbox') #查找位置框
完整的解析pasacal中xml文件代码如下:
输入参数:路径,文件名,是否使用困难样本
输出: bbox,label,difficult (类型np.float32)
def parseXml(data_dir,id,use_difficult=False):
anno = ET.parse(
os.path.join(data_dir, 'Annotations', id + '.xml'))
bbox = list()
label = list()
difficult = list()
for obj in anno.findall('object'):
if not use_difficult and int(obj.find('difficult').text) == 1:
continue
difficult.append(int(obj.find('difficult').text))
bndbox_anno = obj.find('bndbox')
bbox.append([
int(bndbox_anno.find(tag).text) - 1
for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
name = obj.find('name').text.lower().strip()
label.append(VOC_BBOX_LABEL_NAMES.index(name))
bbox = np.stack(bbox).astype(np.float32) #from list to array
label = np.stack(label).astype(np.int32)
difficult = np.array(difficult, dtype=np.bool).astype(np.uint8) # PyTorch don't support np.bool
return bbox, label, difficult