推荐参考:TORCHVISION OBJECT DETECTION FINETUNING TUTORIAL
以VOC2007数据集为例:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from lxml import etree
from xml.etree import ElementTree
"""
VOC2007数据集格式:
└──VOCdevkit
└──VOC2007
└──JPEGImages
└──0.jpg
└──1.jpg
└──2.jpg
└──...
└──Annotations
└──0.xml
└──1.xml
└──2.xml
└──...
└──ImageSets
└──Main
└──train.txt
└──val.txt
└──trainval.txt
└──test.txt
"""
'''
xml文件信息(例):
<annotation>
<folder>JPEGImages</folder>
<filename>0.jpg</filename>
<path>X:/.../.../VOCdevkit/VOC2007/JPEGImages/0.jpg</path>
<source>
<database>Unknown</database>
</source>
<size>
<width>600</width>
<height>800</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>peach</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>100</xmin>
<ymin>250</ymin>
<xmax>150</xmax>
<ymax>260</ymax>
</bndbox>
</object>
<object>
<name>cat</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>200</xmin>
<ymin>50</ymin>
<xmax>550</xmax>
<ymax>370</ymax>
</bndbox>
</object>
</annotation>
'''
class CustomDataset(Dataset): # 自定义数据集
def __init__(self, root, transforms=None, dataset_property="train"): # 初始化方法
self.root = root # 数据路径,应指向".../.../VOCdevkit"
self.transforms = transforms # 预处理方法,一般来说需要传入,注意区分训练数据和验证数据的预处理方法
self.images_dir = os.path.join(self.root, "VOC2007/JPEGImages") # 图像文件存储路径,默认VOC2007/JPEGImages
self.annotations_dir = os.path.join(self.root, "VOC2007/Annotations") # 标注文件存储路径,默认VOC2007/Annotations
self.imagesets_dir = os.path.join(self.root, "VOC2007/ImageSets/Main") # 数据集划分文件存储路径,VOC2007/ImageSets/Main
assert dataset_property in ["train", "val", "trainval", "test"] # 数据集划分文件名应该为train/val/trainval/test,以txt形式存储
with open(os.path.join(self.imagesets_dir, f"{dataset_property}.txt")) as f: # 打开对应txt文件
self.data_names = [i.strip() for i in f.readlines()] # 读取数据,存放于列表(每张图片的名称,不包含.jpg后缀)
self.label_dict = { # 语义标签与int的对应关系,一般从1开始,0表示背景(一般通过读取json或txt文件来获得对应关系)
"cat": 1,
"dog": 2,
"peach": 3
}
def __len__(self): # 获取数据集长度方法 **该方法必须定义**
return len(self.data_names) # 返回数据列表的长度
def __getitem__(self, index): # 采样方法,根据索引index取得对应的图像image和标签target **该方法必须定义**
image_path = os.path.join(self.images_dir, self.data_names[index] + ".jpg") # 图片路径,要求图片为jpg格式
image = Image.open(image_path).convert("RGB") # PIL图像
annotation_path = os.path.join(self.annotations_dir, self.data_names[index] + ".xml") # xml路径
label_names, difficults, boxes, areas = self.read_xml(annotation_path) # 解析xml文件,获取标注信息:标签名称、难识别目标标签、边界框、面积
labels = [self.label_dict[f"{label}"] for label in label_names] # 按照对应关系,将语义标签转化为int类型
labels, difficults, boxes, areas = map(lambda t: torch.as_tensor(t),
[labels, difficults, boxes, areas]) # 将标注信息转化为tensor形式
target = { # 构建最后返回的target,一般还包括image_id(图像id)、masks(分割掩码)、iscrowd(是否为多目标)
"boxes": boxes, # 边界框
"labels": labels, # 标签
"area": areas, # 面积
"isdifficult": difficults # 难识别目标标签
}
if self.transforms is not None: # 如果使用预处理方法
image, target = self.transforms(image, target) # 进行image和target的预处理,transforms函数/类需要复写,以满足对target的变换
return image, target # 返回image,target
def read_xml(self, annotation_path): # 读取xml信息方法
objnames = [] # 用于存放目标名称
difficults = [] # 用于存放难识别目标标签
objboxes = [] # 用于存放边界框
objareas = [] # 用于存放面积
parser = etree.XMLParser(encoding="utf-8") # xml文件解析器
xmlroot = ElementTree.parse(annotation_path, parser=parser).getroot() # 解析xml文件并获得root节点
for object in xmlroot.findall("object"): # 寻找所有object节点并遍历
objnames.append(object.find("name").text) # 获取name节点数据,填入列表
difficults.append(int(object.find("difficult").text)) # 获得difficult节点数据,转换为int,填入列表
objxmin = float(object.find("bndbox/xmin").text) # 获得bndbox/xmin节点数据,转换为float
objymin = float(object.find("bndbox/ymin").text) # 获得bndbox/ymin节点数据,转换为float
objxmax = float(object.find("bndbox/xmax").text) # 获得bndbox/xmax节点数据,转换为float
objymax = float(object.find("bndbox/ymax").text) # 获得bndbox/ymax节点数据,转换为float
assert objxmax > objxmin and objymax > objymin # 检查边界框的长宽是否为正
objboxes.append([objxmin, objymin, objxmax, objymax]) # 边界框[xmin,ymin,xmax,ymax],填入列表
objareas.append((objxmax - objxmin) * (objymax - objymin)) # 面积,填入列表
return objnames, difficults, objboxes, objareas # 返回目标名称、难识别目标标签、边界框、面积