PyTorch 官方实现 FasterRCNN 的步骤详解（一）——数据制作
# Read and parse the PASCAL VOC2012 detection dataset.
class VOC2012DataSet(Dataset):
    """PASCAL VOC2012 object-detection dataset.

    Each sample is an ``(image, target)`` pair, where ``target`` is a dict with
    the keys ``boxes``, ``labels``, ``image_id``, ``area`` and ``iscrowd``.
    """

    def __init__(self, voc_root=r'F:\data_set\VOCtrainval_11-May-2012', transforms=None, train_set=True,
                 class_json_path=r'D:\Python\text\python-learn\FasterRCNN\pascal_voc_classes.json'):
        """
        :param voc_root: directory that contains ``VOCdevkit``; data is read from
            ``<voc_root>/VOCdevkit/VOC2012``.
        :param transforms: optional callable applied as ``transforms(image, target)``.
            Unlike classification transforms, a geometric augmentation such as a
            horizontal flip must also update ``target['boxes']``.
        :param train_set: if True read ``ImageSets/Main/train.txt``, otherwise ``val.txt``.
        :param class_json_path: JSON file mapping class name -> integer label, e.g.
            ``{"aeroplane": 1, "bicycle": 2, ..., "tvmonitor": 20}``.
        """
        # e.g. F:\data_set\VOCtrainval_11-May-2012\VOCdevkit\VOC2012
        self.root = os.path.join(voc_root, 'VOCdevkit', 'VOC2012')
        self.img_root = os.path.join(self.root, 'JPEGImages')
        self.annotations_root = os.path.join(self.root, 'Annotations')

        # The split file lists one image id per line, e.g. "2008_000008".
        txt_name = 'train.txt' if train_set else 'val.txt'
        txt_list = os.path.join(self.root, 'ImageSets', 'Main', txt_name)
        with open(txt_list) as txt_file:
            # Skip blank lines (e.g. a trailing newline) so no bogus ".xml" path is added.
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + '.xml')
                             for line in txt_file if line.strip()]

        # Class name -> integer label; the context manager closes the file handle.
        with open(class_json_path, 'r') as json_file:
            self.class_dict = json.load(json_file)

        self.transforms = transforms

    def __len__(self):
        """Number of annotation files listed in the selected split."""
        return len(self.xml_list)

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        :param idx: index into ``self.xml_list``.
        :return: ``(image, target)``; passed through ``self.transforms`` when set.
        :raises ValueError: if the image file is not a JPEG.
        """
        data = self._load_annotation(idx)
        # e.g. ...\VOC2012\JPEGImages\2008_000008.jpg
        img_path = os.path.join(self.img_root, str(data['filename']))
        image = Image.open(img_path)
        if image.format != 'JPEG':
            raise ValueError('Image format not JPEG')

        target = self._build_target(data, idx)

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def _load_annotation(self, idx):
        """Parse the XML annotation of sample ``idx`` and return its ``annotation`` dict."""
        xml_path = self.xml_list[idx]
        # Read bytes so the XML parser honours any encoding declaration itself
        # (lxml refuses str input that carries an encoding declaration).
        with open(xml_path, 'rb') as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        return self.parse_xml_to_dict(xml)['annotation']

    def _build_target(self, data, idx):
        """
        Build the detection target dict from a parsed annotation dict.

        :param data: dict produced by ``parse_xml_to_dict(...)['annotation']``.
        :param idx: sample index, stored as ``image_id``.
        :return: dict with ``boxes`` (float32, N x 4, xmin/ymin/xmax/ymax),
            ``labels`` (int64), ``image_id``, ``area`` (float32) and ``iscrowd`` (int64).
        """
        boxes, labels, iscrowd = [], [], []
        # An annotation without any <object> has no 'object' key at all.
        for obj in data.get('object', []):
            bndbox = obj['bndbox']
            boxes.append([float(bndbox['xmin']), float(bndbox['ymin']),
                          float(bndbox['xmax']), float(bndbox['ymax'])])
            labels.append(self.class_dict[obj['name']])
            # The VOC 'difficult' flag is stored under 'iscrowd'; a missing or
            # empty <difficult> tag is treated as 0.
            iscrowd.append(int(obj.get('difficult') or 0))

        # reshape keeps shape (N, 4) even when N == 0, so the column slicing below works.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        # Box area = height * width.
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target['boxes'] = boxes
        target['labels'] = labels
        target['image_id'] = image_id
        target['area'] = area
        target['iscrowd'] = iscrowd
        return target

    def parse_xml_to_dict(self, xml):
        """
        Recursively convert an XML element into nested dicts.

        Leaf elements become ``{tag: text}``. Every ``<object>`` child is collected
        into a list, because an image may contain several objects. For example,
        a VOC annotation yields
        ``{'annotation': {'filename': '2008_000008.jpg',
        'size': {'width': '500', 'height': '442', 'depth': '3'},
        'object': [{'name': 'horse', 'bndbox': {...}, ...}, ...], ...}}``.

        :param xml: parsed XML element, e.g. the root returned by ``etree.fromstring``.
        :return: ``{xml.tag: <text or dict of children>}``; all values are strings.
        """
        if len(xml) == 0:  # leaf element: return its text directly
            return {xml.tag: xml.text}
        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into the child
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                result.setdefault(child.tag, []).append(child_result[child.tag])
        return {xml.tag: result}

    def get_height_and_width(self, idx):
        """Return ``(height, width)`` of sample ``idx`` read from its XML, without opening the image."""
        data = self._load_annotation(idx)
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    @staticmethod
    def collate_fn(batch):
        """Turn a list of ``(image, target)`` pairs into ``(images, targets)`` tuples without stacking."""
        return tuple(zip(*batch))