PASCAL VOC2012 自定义Dataset

目录

1. PASCAL VOC2012文件架构

2. 读取VOC 数据

3. 自己的数据集自作


个人笔记

代码及资料来源PASCAL VOC2012数据集讲解与制作自己的数据集_哔哩哔哩_bilibili

1. PASCAL VOC2012文件架构

VOCdevkit
    └── VOC2012
         ├── Annotations               所有的图像标注信息(XML文件)
         ├── ImageSets    
         │   ├── Action                人的行为动作图像信息
         │   ├── Layout                人的各个部位图像信息
         │   │
         │   ├── Main                  目标检测分类图像信息
         │   │     ├── train.txt       训练集(5717)
         │   │     ├── val.txt         验证集(5823)
         │   │     └── trainval.txt    训练集+验证集(11540)
         │   │
         │   └── Segmentation          目标分割图像信息
         │         ├── train.txt       训练集(1464)
         │         ├── val.txt         验证集(1449)
         │         └── trainval.txt    训练集+验证集(2913)
         │ 
         ├── JPEGImages                所有图像文件
         ├── SegmentationClass         语义分割png图(基于类别)
         └── SegmentationObject        实例分割png图(基于目标)
  • train.txtval.txttrainval.txt文件里是对应标注文件的索引,每一行对应一个索引信息,也是一个图片的名称

  •  Annotations下一个XML文件对应一张图像的标注信息

XML标注文件中包含了 :

filename,通过在字段能够在JPEGImages 文件夹中能够找到对应的图片。

size记录了对应图像的宽、高以及channel信息。

每一个object代表一个目标,name===该目标的名称,pose===目标的姿势(朝向),truncated===目标是否完整,difficult===该目标的检测难易程度(0简单,1困难)

bndbox===边界框信息,是(xmin,ymin,xmax,ymax)左上角和右下角

  •  通过在标注文件中的filename字段在JPEGImages 文件夹中找到对应的图片。

2. 读取VOC 数据

代码如下:

 transforms定义:(faster_rcnn项目中的transforms.py)

目标检测,如果反转的话boxx也要反转

#     Compose 组合多个transform函数     ToTensor将PIL图像转为Tensor    RandomHorizontalFlip水平翻转 图像+++bboxes
import random
from torchvision.transforms import functional as F

class Compose(object):
    """组合多个transform函数"""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class ToTensor(object):
    """将PIL图像转为Tensor"""
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target

class RandomHorizontalFlip(object):                       #  目标检测,还要翻转对应的bbox
    """随机水平翻转图像以及bboxes"""
    def __init__(self, prob=0.5):
        self.prob = prob
    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)  # 水平翻转图片
            bbox = target["boxes"]
            # bbox: batch, xmin, ymin, xmax, ymax
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]    # 翻转对应bbox坐标信息,看PPT   bbox的维度:(个数,xmin,ymin,xmax,ymax)
            target["boxes"] = bbox                       # bbox原点在左上角
        return image, target

# bbox维度: batch, xmin, ymin, xmax, ymax
水平反转说明以下图为例:

“:”表示batch维度;“0 、2”代表Xmin  Xmax;      水平翻转Ymin Ymax不会变

新的Xmin==图像宽度-Xmax              代码:bbox[:, [0]] = width - bbox[:, [2]]

新的Xmax==图像宽度-Xmin              代码:bbox[:, [2]] = width - bbox[:, [0]]

综合:bbox[:, [0, 2]] = width - bbox[:, [2, 0]]

 

VOCDataSet定义:(faster_rcnn项目中的my_dataset.py)

import numpy as np
from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'


# """读取解析PASCAL VOC2007/2012数据集"""
class VOCDataSet(Dataset):
    # __init__搭建一些路径,定义东西self.-  self.-  self.-   方便后序__len__、__getitem__等方法调用

    # __init__根据传入的voc_root、transforms、txt_name参数生成以下:
    # self.root、self.img_root、self.annotations_root    定义VOC2012所在的路径、图像路径、标注路径
    # self.xml_list、self.class_dict、self.transforms    定义每张图像的标注路径放到列表、类别索引、数据预处理
    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        #  找文件路径
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")  # VOC2012所在的路径
        self.img_root = os.path.join(self.root, "JPEGImages")              # 图像路径
        self.annotations_root = os.path.join(self.root, "Annotations")     #标注路径
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)  # 找到train.txt
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]
        # xml_list = [F:\data\voc_data\VOCdevkit\VOC2012\Annotations/2008_000008.xml, .....]

        # 下面这部分是将  xml_list中有Object的xml文件路径 给到 self.xml_list
        # 如果xml_list中每一个xml标注信息都有object, 那么self.xml_list==xml_list
        self.xml_list = []
        for xml_path in xml_list:
            if os.path.exists(xml_path) is False:  # 检查是否存在
                print(f"Warning: not found '{xml_path}', skip this annotation file.")
                continue
            with open(xml_path) as fid:             # 打开一个xml文件
                xml_str = fid.read()                # 把文档内的内容全部提取出来
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]  #将xml文件解析成字典形式----跳到函数去看看
            # #直接看遍历完的结果看data就行
            if "object" not in data:
                print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                continue
            self.xml_list.append(xml_path)
        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # 读取label dict
        json_file = './pascal_voc_classes.json'
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)    #20个类别代表20类,下标1--20   因为0代表是背景,所以从下标1开始
        # 定义数据预处理
        self.transforms = transforms


    # __len__是返回样本数
    # 根据__init__定义好的self.root...self.xml_list...self.transforms
    # 列表self.xml_list存储   每一张图像对应的一个XML文件
    # 所以 列表self.xml_list长度===图像个数===样本数
    def __len__(self):
        return len(self.xml_list)


    # self是dataset中任何方法的第一个参数,也是必须传入的参数, __getitem__还需要传入一个index参数
    # 根据__init__定义好的self.root... 随机/按序 取index去找图像,最终返回需要的image、label
    def __getitem__(self, idx):
        xml_path = self.xml_list[idx]    # xml文件路径
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]          # xml文件内信息,转为为字典
        img_path = os.path.join(self.img_root, data["filename"])  # 图片路径
        # 读取image
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))
        # 提取label
        boxes = []
        labels = []
        iscrowd = []     # 是否难预测
        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:  # 将一张图片中标注的一个或多个物体相关信息,存到上述3个列表
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])
            # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)
        # 将label信息转化为tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])  # 计算面积
        # 整合label
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd    # 是否难预测
        # 数据预处理
        if self.transforms is not None:
            image, target = self.transforms(image, target)  # 看自定义的transforms,里面的函数需要传入 image和target
        # 结束,返回image, target
        return image, target


    # 自定义的获取图像高宽
    def get_height_and_width(self, idx):    # 通过read xml获取相关信息
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width


    # 自定义 将xml文件解析成字典形式
    def parse_xml_to_dict(self, xml):
        #xml:  xml tree obtained by parsing XML file contents using lxml.etree
        if len(xml) == 0:             # 遍历到底层,直接返回tag对应的信息
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}



    # 自定义的dataset好像最后都有这个,好像是 打包image和它的label
    # 需要时再详细查看
    @staticmethod
    def collate_fn(batch):
        return tuple(zip(*batch))

使用示例:

  • train/val时使用:
data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                     transforms.RandomHorizontalFlip(0.5)]),
                                      # 0.5随机反转的概率
    "val": transforms.Compose([transforms.ToTensor()])
}

train_dataset = VOCDataSet(VOC_root, "2012", data_transform["train"], "train.txt")

train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                pin_memory=True,
                                                num_workers=0,
                                                collate_fn=train_dataset.collate_fn)
  • 随便取出几张看看并绘制框
train_data_set = VOCDataSet('F:\\data\\voc_data\\VOCdevkit', "2012", data_transform["train"], "train.txt")
#  train_data_set[0]就会自动调用__getitem__,返回索引0对应的image和label
#  0索引对应的图片===self.xml_list[0]标注信息的图片
#  回看__getitem__方法,顺序是self.xml_list[index]...打开...最后返回image和label

for index in random.sample(range(0, len(train_data_set)), k=2):  # 随机取出2个绘出  标框
    img, target = train_data_set[index]    # 此时是tensor格式
    img = ts.ToPILImage()(img)             # 变为图像格式
    plot_img = draw_objs(img,
                         target["boxes"].numpy(),
                         target["labels"].numpy(),
                         np.ones(target["labels"].shape[0]),
                         category_index=category_index, #字典--如下形式
    #{'1': 'aeroplane', '2': 'bicycle', ...,  '20': 'tvmonitor'}
                         box_thresh=0.5,
                         line_thickness=3,
                         font='arial.ttf',
                         font_size=20)
    plt.imshow(plot_img)
    plt.show()

上述代码中draw_objs 绘图函数(faster_rcnn项目中的draw_box_utils.py文件)

3. 自己的数据集自作

仿照2中代码构建自己的dataset;或者使用2中的代码 

若使用     2中的代码     需要和VOC一样的数据集存放架构(仔细看PASCAL VOC2012文件架构)

如果只有图像和标注信息,需要生成train.txt、val.txt

代码如下:(faster_rcnn项目中的split_data.py)

import os
import random

#*********************************   train.txt  val.txt 生成   ******************************************************

def main():
    random.seed(0)  # 设置随机种子,保证随机结果可复现

    files_path = "F:/data/voc_data/VOCdevkit/VOC2012/Annotations"
    assert os.path.exists(files_path), "path: '{}' does not exist.".format(files_path)

    val_rate = 0.2

    files_name = sorted([file.split(".")[0] for file in os.listdir(files_path)])
    files_num = len(files_name)
    val_index = random.sample(range(0, files_num), k=int(files_num*val_rate))
    train_files = []
    val_files = []
    for index, file_name in enumerate(files_name):
        if index in val_index:
            val_files.append(file_name)
        else:
            train_files.append(file_name)

    try:
        train_f = open("train.txt", "x")
        eval_f = open("val.txt", "x")
        train_f.write("\n".join(train_files))
        eval_f.write("\n".join(val_files))
    except FileExistsError as e:
        print(e)
        exit(1)


if __name__ == '__main__':
    main()

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值