一、pascal_voc_parser.py——get_data
在train_frcnn.py中遇到的第一个问题是如何加载数据,代码如下:
# Parser step of train_frcnn.py: read the dataset's annotation metadata
# (not pixel data) for every image.
#
# Input:
#   options.train_path -- dataset root directory; it must contain a
#                         'VOC2012' sub-folder.
#
# Output:
#   all_imgs      -- list with one dict per image, holding 'filepath',
#                    'width', 'height', 'imageset' (which split the image
#                    belongs to, training or test) and 'bboxes'. Each
#                    element of 'bboxes' is a dict with 'class', 'x1', 'y1',
#                    'x2', 'y2' and 'difficult' (see the example below).
#                    NOTE(review): the parser may also store an 'imageid'
#                    key; it does not appear in the example -- confirm.
# Example of one entry:
'''
all_img_data[0] = {'width': 500, 'height': 500,
'bboxes': [{'y2': 500, 'y1': 27, 'x2': 183, 'x1': 20, 'class': 'person', 'difficult': False},
{'y2': 500, 'y1': 2, 'x2': 249, 'x1': 112, 'class': 'person', 'difficult': False},
{'y2': 490, 'y1': 233, 'x2': 376, 'x1': 246, 'class': 'person', 'difficult': False},
{'y2': 468, 'y1': 319, 'x2': 356, 'x1': 231, 'class': 'chair', 'difficult': False},
{'y2': 450, 'y1': 314, 'x2': 58, 'x1': 1, 'class': 'chair', 'difficult': True}], 'imageset': 'test',
'filepath': './datasets/VOC2007/JPEGImages/000910.jpg'}
'''
#   classes_count -- number of annotated boxes per class name, e.g.
#     {'sheep': 8, 'horse': 5, 'bg': 0, 'bicycle': 7, 'motorbike': 15,
#      'cow': 6, 'car': 34, 'aeroplane': 2, 'dog': 4, 'bus': 4, 'cat': 6,
#      'person': 113, 'train': 7, 'diningtable': 4, 'bottle': 3, 'sofa': 9,
#      'pottedplant': 7, 'tvmonitor': 7, 'chair': 27, 'bird': 6, 'boat': 7}
#   class_mapping -- maps each class name (str) to an integer label.
#     NOTE(review): confirm which integer label the background class 'bg'
#     receives; classes_count above lists 'bg' with a count of 0.
(all_imgs, classes_count, class_mapping) = get_data(options.train_path)
本节将进入到get_data方法中查看如何加载数据
参考:http://geyao1995.com/Faster_rcnn%E4%BB%A3%E7%A0%81%E7%AC%94%E8%AE%B0_train_1/
get_data方法比较简单,就是将VOC2012数据集中的xml标注文件逐个解析出来,输出的是图像的描述信息(文件路径、宽高、标注框等),而不是图像的像素数据。处理大规模图像数据集时,一般只应预先保存所有图像的这类描述信息,在需要时再读取像素,否则内存根本装不下全部图像的像素数据。
import os
import cv2
import xml.etree.ElementTree as ET
from tqdm import tqdm
# Contract of get_data(input_path), defined below.
# Input:
#   input_path -- root directory of the dataset. get_data only builds the
#                 path <input_path>/VOC2012 and reads Annotations/,
#                 JPEGImages/ and ImageSets/Main/{trainval,train,val}.txt
#                 from it; ImageSets/Main/test.txt is read only if present.
# Output:
#   all_imgs -- list with one dict per image. Going by the example below,
#               each dict holds 'filepath', 'width', 'height', 'imageset'
#               (which split the image belongs to, training or test) and
#               'bboxes'; every element of 'bboxes' is a dict with 'class',
#               'x1', 'y1', 'x2', 'y2' and 'difficult'.
#               NOTE(review): an 'imageid' key may also be stored, and the
#               class key may be named 'class_name' rather than 'class'; the
#               code that builds these dicts is not visible here -- confirm.
# Example (NOTE(review): the filepath points at VOC2007, whereas get_data
# below only scans VOC2012):
'''
all_img_data[0] = {'width': 500, 'height': 500,
'bboxes': [{'y2': 500, 'y1': 27, 'x2': 183, 'x1': 20, 'class': 'person', 'difficult': False},
{'y2': 500, 'y1': 2, 'x2': 249, 'x1': 112, 'class': 'person', 'difficult': False},
{'y2': 490, 'y1': 233, 'x2': 376, 'x1': 246, 'class': 'person', 'difficult': False},
{'y2': 468, 'y1': 319, 'x2': 356, 'x1': 231, 'class': 'chair', 'difficult': False},
{'y2': 450, 'y1': 314, 'x2': 58, 'x1': 1, 'class': 'chair', 'difficult': True}], 'imageset': 'test',
'filepath': './datasets/VOC2007/JPEGImages/000910.jpg'}
'''
#   classes_count -- number of annotated boxes per class name, e.g.
#     {'sheep': 8, 'horse': 5, 'bg': 0, 'bicycle': 7, 'motorbike': 15,
#      'cow': 6, 'car': 34, 'aeroplane': 2, 'dog': 4, 'bus': 4, 'cat': 6,
#      'person': 113, 'train': 7, 'diningtable': 4, 'bottle': 3, 'sofa': 9,
#      'pottedplant': 7, 'tvmonitor': 7, 'chair': 27, 'bird': 6, 'boat': 7}
#   class_mapping -- maps each class name (str) to an integer label.
#     NOTE(review): confirm which integer label the background class 'bg'
#     receives; the classes_count example lists 'bg' with a count of 0.
def get_data(input_path):
all_imgs = []
classes_count = {}
class_mapping = {}
# parsing 정보 확인 Flag
visualise = False
# pascal voc directory + 2012
data_paths = [os.path.join(input_path,'VOC2012')]
print('Parsing annotation files')#解析注释文件
for data_path in data_paths:
print(data_paths)
annot_path = os.path.join(data_path, 'Annotations')
imgs_path = os.path.join(data_path, 'JPEGImages')
#ImageSets/Main directory
imgsets_path_trainval = os.path.join(data_path, 'ImageSets', 'Main', 'trainval.txt')
imgsets_path_train = os.path.join(data_path, 'ImageSets', 'Main', 'train.txt')
imgsets_path_val = os.path.join(data_path, 'ImageSets', 'Main', 'val.txt')
imgsets_path_test = os.path.join(data_path, 'ImageSets', 'Main', 'test.txt')
trainval_files = []
train_files = []
val_files = []
test_files = []
with open(imgsets_path_trainval) as f:
for line in f:
trainval_files.append(line.strip() + '.jpg')
with open(imgsets_path_train) as f:
for line in f:
train_files.append(line.strip() + '.jpg')
with open(imgsets_path_val) as f:
for line in f:
val_files.append(line.strip() + '.jpg')
# test-set not included in pascal VOC 2012
if os.path.isfile(imgsets_path_test):
with open(imgsets_path_test) as f:
for line in f:
test_files.append(line.strip() + '.jpg')
# 이미지셋 txt 파일 read 예외처리
# try:
# with open(imgsets_path_trainval) as f:
# for line in f:
# trainval_files.append(line.strip() + '.jpg')
# except Exception as e:
# print(e)
#
# try:
# with open(imgsets_path_test) as f:
# for line in f:
# test_files.append(line.strip() + '.jpg')
# except Exception as e:
# if data_path[-7:] == 'VOC2012':
# # this is expected, most pascal voc distibutions dont have the test.txt file
# pass
# else:
# print(e)
# annotation 파일 read
annots = [os.path.join(annot_path, s) for s in os.listdir(annot_path)]#获取所有Annotations文件夹下的xml文件完整路径+文件名
idx = 0
annots = tqdm(annots)#显示进度条,生成一个tqdm对象
for annot in annots:
# try:
exist_flag = False
idx += 1
annots.set_description("Processing %s" % annot.split(os.sep)[-
最低0.47元/天 解锁文章
5294

被折叠的 条评论
为什么被折叠?



