在《计算机视觉技巧合集(二)如何读取数据之目标检测篇》中VOC数据的读取只是简单读取Annotations文件夹下的标注文件和JPEGImages文件夹下的图像路径,并没有划分成训练模型时需要用到的训练集、验证集和测试集,因此,这里进行补充。
在VOC2012数据集的ImageSets\Main文件夹下,可以看到有很多已经划分好训练集、验证集和测试集的文本文件,当然我们也可以自己划分不同于官方的训练集、验证集和测试集,可以参考《计算机视觉技巧合集(一)如何读取数据》的划分方法,只需将划分好的图像名称按照VOC官方的保存形式保存在文本文件中即可。
以下是VOC数据集中已有的ImageSets\Main\train.txt:
2008_000008
2008_000015
2008_000019
2008_000023
......
以下是读取VOC数据集的示例代码:
import os
import numpy as np
import xml.etree.ElementTree as ET
class_names = [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog','horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ]
def load_data_from_txt(text, img_root, anno_root, remove_difficult=False):
# 读取文本文件,获取图像名称列表
with open(text, 'r') as f:
img_names = f.readlines()
# 获取标注文件路径列表,由于读取的图像名称是"2008_000013\n"这种形式,因此使用strip()函数去除'\n'
anno_paths = [os.path.join(anno_root, img_name.strip()+".xml") for img_name in img_names]
# 获取图像路径列表
img_paths = [os.path.join(img_root, img_name.strip()+".jpg") for img_name in img_names]
all_labels = []
for anno_path in anno_paths:
target = ET.parse(anno_path)
root = target.getroot()
# 获得图像的高和宽
size = root.find("size")
h = int(size.find("height").text)
w = int(size.find("width").text)
# 获取这张图像中全部的标签(类别+真实框)
labels = []
for object in root.iter("object"):
# 获得辨认难度
difficult = int(object.find("difficult").text) == 1
# remove_difficult置1且difficult为1,那么跳过
if difficult and remove_difficult:
continue
# 获取类别索引
cls_name = object.find("name").text.strip()
cls_index = int(class_names.index(cls_name))
# 获取全部的标注真实框
bndbox = object.find("bndbox")
bbox = []
points = ['xmin', 'ymin', 'xmax', 'ymax']
for point in points:
pt = float(bndbox.find(point).text)
bbox.append(pt)
# 添加标签
label = [cls_index] + bbox
labels.append(label)
# 保证每张图像都有对应的标签,没有标签的图像生成一个全部值为0的标签,便于之后进行坐标转换
if len(labels) == 0:
labels = np.zeros((1, 5))
else:
labels = np.array(labels, dtype=np.float32)
# 返回全部标签
all_labels.append(labels)
return img_paths, all_labels
if __name__ == "__main__":
text_path = r"G:\datasets\VOCdevkit\VOC2012\ImageSets\Main\train.txt"
img_root = r"G:\datasets\VOCdevkit\VOC2012\JPEGImages"
anno_root = r"G:\datasets\VOCdevkit\VOC2012\Annotations"
img_paths, all_labels = load_data_from_txt(text_path, img_root, anno_root, remove_difficult=True)
print(f"图像总数: {len(img_paths)}")
print(f"标签总数: {len(all_labels)}")
# 展示前5个图像的标签
for index, labels in enumerate(all_labels):
print(f"第{index}张图像全部的标签:")
print(labels)
if index == 5:
break
程序运行结果如下: