# 推荐参考:PyTorch 官方教程 "TorchVision Object Detection Finetuning Tutorial"
# 以下以 VOC2007 数据集为例:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from lxml import etree
from xml.etree import ElementTree
"""
VOC2007数据集格式:
└──VOCdevkit
└──VOC2007
└──JPEGImages
└──0.jpg
└──1.jpg
└──2.jpg
└──...
└──Annotations
└──0.xml
└──1.xml
└──2.xml
└──...
└──ImageSets
└──Main
└──train.txt
└──val.txt
└──trainval.txt
└──test.txt
"""
'''
xml文件信息(例):
<annotation>
    <folder>JPEGImages</folder>
    <filename>0.jpg</filename>
    <path>X:/.../.../VOCdevkit/VOC2007/JPEGImages/0.jpg</path>
    <source>
        <database>Unknown</database>
    </source>
    <size>
        <width>600</width>
        <height>800</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>peach</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>100</xmin>
            <ymin>250</ymin>
            <xmax>150</xmax>
            <ymax>260</ymax>
        </bndbox>
    </object>
    <object>
        <name>cat</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>200</xmin>
            <ymin>50</ymin>
            <xmax>550</xmax>
            <ymax>370</ymax>
        </bndbox>
    </object>
</annotation>
'''
class CustomDataset(Dataset): # 自定义数据集
def __init__(self, root, transforms=None, dataset_property="train"): # 初始化方法
self.root = root # 数据路径,应指向".../.../VOCdevkit"
self.transforms = transforms # 预处理方法,一般来说需要传入,注意区分训练数据和验证数据的预处理方法
self.images_dir = os.path.join(self.root,