【霹雳吧啦-目标检测篇】faster_RCNN (my_dataset.py)

b站链接:

2-自定义DataSet

结尾有整个文件的详细代码。

目录

b站链接:

1.建立类:

2.初始化函数:

 2.len 函数:

 3.  geiitem 函数:

4.获取高度宽度:

6.transforms代码:

 详细代码:


1.建立类:

class VOCDataSet(Dataset):

 此处建立了一个类( VOCDataSet ) , 此类继承了:

from torch.utils.data import Daraset

实现了 下文中的  len  和 getitem 两个方法。 

2.初始化函数:

def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):

voc_root : 训练集所在的根目录路径

year: 在视频中并没传入此参数,

            因为voc训练集有2007和2012 两个版本,此处我觉得是选则其中一个版本进行传入

transforms = None : 预处理方法,设置为True时会执行文件夹中的transforms.py文件

txt_name: str = "train.txt"  : 此处视频中使用的是  train_set = True ,意思是执行根目录下

                                              VOCdevkit文件夹下ImageSetsMain文件夹下的train.txt

                                              文件。对于新版可以对  val.txt  和 train.txt 文件进行选择。

----------------------------- ----------------------------- ----------------------------- -----------------------------   

assert year in ["2007", "2012"], "year must be in ['2007', '2012']"

 判断使用的是2007还是2012版本的VOC数据集,若不是会输出提示信息

----------------------------- ----------------------------- ----------------------------- -----------------------------   

# 增加容错能力
if "VOCdevkit" in voc_root:
    self.root = os.path.join(voc_root, f"VOC{year}")
else:
    self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")

    self.img_root = os.path.join(self.root, "JPEGImages")
    self.annotations_root = os.path.join(self.root, "Annotations")

对 VOCdevkit 文件夹是否在根目录下进行判断,获取 文件夹下的  VOC{year}  的地址存入

self.root,之后将图片(img_root)标签(annotation_root)的地址分别录入

 ----------------------------- ----------------------------- ----------------------------- -----------------------------  

 # read train.txt or val.txt file
 txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
 assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

此处就使用到了  txt_name ,获取val.txt  或 train.txt 的地址

 ----------------------------- ----------------------------- ----------------------------- -----------------------------  

with open(txt_path) as read:
    xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
    for line in read.readlines() if len(line.strip()) > 0]

self.xml_list = []

 读取 txt 文件,得到val.txt  或 train.txt 中每一个文件对应的 xml 文件,将所有 xml 文件信息保存到 xml_list 当中。

 ----------------------------- ----------------------------- ----------------------------- -----------------------------  

# 检查XML文件的存在性:如果文件不存在,则打印警告信息并跳过当前循环,继续处理下一个文件
for xml_path in xml_list:
    if os.path.exists(xml_path) is False:
        print(f"Warning: not found '{xml_path}', skip this annotation file.")
        continue

check file 进一步处理之前生成的XML文件列表, 检查这些XML文件的存在性和内容,并将满足条件的XML文件路径添加到另一个列表中。 同时,它还读取一个包含类别信息的JSON文件

 ----------------------------- ----------------------------- ----------------------------- -----------------------------  

            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
对于存在的 XML 文件,代码读取文件内容,然后使用  etree.fromstring  方法将其解析为一个XML  对象。接着,调用  self.parse_xml_to_dict  方法来获取XML数据,并从中提取出"annotation"部分

 ----------------------------- ----------------------------- ----------------------------- -----------------------------  

            if "object" not in data:
                print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                continue
            self.xml_list.append(xml_path)
            # 如果XML文件存在且包含至少一个对象,它的路径将被添加到self.xml_list列表中

 代码检查解析后的数据是否包含"object"键。如果不包含,意味着该XML文件没有标注任何对象,因此打印信息并跳过当前循环

----------------------------- ----------------------------- ----------------------------- -----------------------------   

             # 断言检查
            """确保self.xml_list列表中有至少一个元素"""
        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # 读取类别信息read class_indict
        json_file = './pascal_voc_classes.json'  # 该文件包含Pascal VOC数据集的类别信息
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

----------------------------- ----------------------------- ----------------------------- -----------------------------   

 2.len 函数:

    def __len__(self):
        return len(self.xml_list)

 3.  geiitem 函数:

def __getitem__(self, idx):

传入idx索引值,返回索引值的图片以及图片信息

----------------------------- ----------------------------- ----------------------------- -----------------------------  

xml_path = self.xml_list[idx]
with open(xml_path) as fid:
    xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]

 读取获得 xml 文件,并将xml文件中 annotation 下面的信息存入 data 中, 

xml文件如下:

----------------------------- ----------------------------- ----------------------------- -----------------------------  

img_path = os.path.join(self.img_root, data["filename"])
image = Image.open(img_path)
if image.format != "JPEG":
    raise ValueError("Image '{}' format not JPEG".format(img_path))

img_path:   获取图片的路径,通过根目录和 filename 中的内容拼接获得。

----------------------------- ----------------------------- ----------------------------- -----------------------------  

boxes = []
labels = []
iscrowd = []
assert "object" in data, "{} lack of object information.".format(xml_path)
for obj in data["object"]:
    xmin = float(obj["bndbox"]["xmin"])
    xmax = float(obj["bndbox"]["xmax"])
    ymin = float(obj["bndbox"]["ymin"])
    ymax = float(obj["bndbox"]["ymax"])
# 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
    if xmax <= xmin or ymax <= ymin:
         print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
         continue
    boxes.append([xmin, ymin, xmax, ymax])
    labels.append(self.class_dict[obj["name"]])

 box[ ]:中存储的是 object 下的 <bndbox> 信息

labels[ ] : 存储的是分类目标的索引值,即 <name>train<name> 中的 train 的索引值。

获取xml 中 object 的位置参数,得到XY的max和min。并获取每个 object 的类别名称name

 

----------------------------- ----------------------------- ----------------------------- -----------------------------  

if "difficult" in obj:   # object中有一个<difficult>0</difficult>,检测: 0 容易, 1困难
                iscrowd.append(int(obj["difficult"]))
else:
                iscrowd.append(0)

判断此目标是否是检测困难的目标 

----------------------------- ----------------------------- ----------------------------- -----------------------------   

# convert everything into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])    # 计算面积
#      (   Ymax     -    Ymin    ) * (    Xmax    -   Xmin     )

转换数据为 tensor 类型,并计算 object 的面积

 ----------------------------- ----------------------------- ----------------------------- -----------------------------   

4.获取高度宽度:

def get_height_and_width(self, idx):
# read xml
   xml_path = self.xml_list[idx]
   with open(xml_path) as fid:
       xml_str = fid.read()
   xml = etree.fromstring(xml_str)
   data = self.parse_xml_to_dict(xml)["annotation"]
   data_height = int(data["size"]["height"])
   data_width = int(data["size"]["width"])
   return data_height, data_width

 ----------------------------- ----------------------------- ----------------------------- -----------------------------   

5.将xml转换为字典形式:

    def parse_xml_to_dict(self, xml):
        """
        将xml文件解析成字典形式(以字典的形式存储),参考tensorflow的recursive_parse_xml_to_dict
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

 ----------------------------- ----------------------------- ----------------------------- -----------------------------   

6.transforms代码:

包含了准换 PIL 图像到 tensor, 随机水平翻转

import random
from torchvision.transforms import functional as F


class Compose(object):
    """组合多个transform函数"""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class ToTensor(object):
    """将PIL图像转为Tensor"""
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target


class RandomHorizontalFlip(object):
    """随机水平翻转图像以及bboxes"""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)  # 水平翻转图片
            bbox = target["boxes"]
            # bbox: xmin, ymin, xmax, ymax
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # 翻转对应bbox坐标信息
            target["boxes"] = bbox
        return image, target

 

 详细代码:

import numpy as np
from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree


class VOCDataSet(Dataset):
    """读取解析PASCAL VOC2007/2012数据集"""

    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        # 增加容错能力
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

        self.xml_list = []
        """
        check file
        进一步处理之前生成的XML文件列表,
        检查这些XML文件的存在性和内容,并将满足条件的XML文件路径添加到另一个列表中。
        同时,它还读取一个包含类别信息的JSON文件
        """
        # 检查XML文件的存在性:如果文件不存在,则打印警告信息并跳过当前循环,继续处理下一个文件。
        for xml_path in xml_list:
            if os.path.exists(xml_path) is False:
                print(f"Warning: not found '{xml_path}', skip this annotation file.")
                continue

            # check for targets
            # 读取并解析XML文件
            """
            对于存在的XML文件,代码读取文件内容,然后使用etree.fromstring方法将其解析为一个XML对象。
            接着,它调用self.parse_xml_to_dict方法(这个方法没有在提供的代码段中定义,但可以假设它将XML对象转换为字典格式)来获取XML数据,
            并从中提取出"annotation"部分
            """
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]

            # 检查XML文件内容
            """代码检查解析后的数据是否包含"object"键。如果不包含,意味着该XML文件没有标注任何对象,因此打印信息并跳过当前循环"""
            if "object" not in data:
                print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                continue

            # 添加满足条件的XML文件路径
            """如果XML文件存在且包含至少一个对象,它的路径将被添加到self.xml_list列表中"""
            self.xml_list.append(xml_path)

            # 断言检查
            """确保self.xml_list列表中有至少一个元素"""
        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # 读取类别信息read class_indict
        json_file = './pascal_voc_classes.json'  # 该文件包含Pascal VOC数据集的类别信息
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

    def __len__(self):
        return len(self.xml_list)

    def __getitem__(self, idx):     # 只需要传入idx索引值,返回索引值的图片以及图片信息
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        boxes = []
        labels = []
        iscrowd = []
        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue
            
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:   # object中有一个<difficult>0</difficult>,检测: 0 容易, 1困难
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])    # 计算面积
        #      (   Ymax     -    Ymin    ) * (    Xmax    -   Xmin     )

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def get_height_and_width(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    def parse_xml_to_dict(self, xml):
        """
        将xml文件解析成字典形式(以字典的形式存储),参考tensorflow的recursive_parse_xml_to_dict
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

    def coco_index(self, idx):
        """
        该方法是专门为pycocotools统计标签信息准备,不对图像和标签作任何处理
        由于不用去读取图片,可大幅缩减统计时间

        Args:
            idx: 输入需要获取图像的索引
        """
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        # img_path = os.path.join(self.img_root, data["filename"])
        # image = Image.open(img_path)
        # if image.format != "JPEG":
        #     raise ValueError("Image format not JPEG")
        boxes = []
        labels = []
        iscrowd = []
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            iscrowd.append(int(obj["difficult"]))

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        return (data_height, data_width), target

  • 45
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值