智能数字图像处理之FastRCNN(pytorch)代码解读之my_dataset.py

def __init__(self, voc_root, transforms, train_set=True):-》voc_root为数据集所在根目录(其下包含VOCdevkit),transforms为预处理方法,train_set为布尔变量(True读取train.txt,False读取val.txt)
        self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
        self.img_root = os.path.join(self.root, "JPEGImages")-》图像根目录
        self.annotations_root = os.path.join(self.root, "Annotations")-》标注信息根目录
        if train_set:
            txt_list = os.path.join(self.root, "ImageSets", "Main", "train.txt")-》拼接出train.txt文件的路径
        else:
            txt_list = os.path.join(self.root, "ImageSets", "Main", "val.txt")-》拼接出val.txt文件的路径
        with open(txt_list) as read:
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")-》打开txt文件,把每一行的图像名拼接成对应xml标注文件的路径,存入xml_list列表
                             for line in read.readlines()]

        # read class_indict
        try:
            json_file = open('./pascal_voc_classes.json', 'r')-》载入写有类别名称和索引的json文件
            self.class_dict = json.load(json_file)-》加入到class_dict 这个变量当中
        except Exception as e:
            print(e)
            exit(-1)

        self.transforms = transforms


    def __len__(self):
        return len(self.xml_list)-》返回数据集文件的个数

 

def __getitem__(self, idx):-》idx为索引值
        # read xml
        xml_path = self.xml_list[idx]-》获取xml文件的路径
        with open(xml_path) as fid:-》打开xml文件
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)-》将xml字符串解析为元素树
        data = self.parse_xml_to_dict(xml)["annotation"]-》再将xml文件信息传入到parse_xml_to_dict(xml文件信息转化为字典)方法中
        img_path = os.path.join(self.img_root, data["filename"])-》拼接成图像路径
        image = Image.open(img_path)-》打开图片路径
        if image.format != "JPEG":
            raise ValueError("Image format not JPEG")-》如果不是JPEG格式则报错
        boxes = []
        labels = []
        iscrowd = []
        for obj in data["object"]:-》遍历字典中的对象信息
            xmin = float(obj["bndbox"]["xmin"])-》获取xmin 的值,xmin:x坐标最小值
            xmax = float(obj["bndbox"]["xmax"])-》获取xmax的值,xmax:x坐标最大值
            ymin = float(obj["bndbox"]["ymin"])-》获取ymin的值,ymin:y坐标最小值
            ymax = float(obj["bndbox"]["ymax"])-》获取ymax的值,ymax:y坐标最大值
            boxes.append([xmin, ymin, xmax, ymax])-》添加到boxes的变量中
            labels.append(self.class_dict[obj["name"]])-》添加到labels的变量中
            iscrowd.append(int(obj["difficult"]))-》添加到iscrowd的变量中
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])-》把所有东西都转化成tensor,张量
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])-》计算每个边界框的面积:(ymax-ymin)*(xmax-xmin)

        target = {}-》创建存放当前图像标注信息的target字典
        target["boxes"] = boxes-》添加boxes(框)
        target["labels"] = labels-》添加labels(标签)
        target["image_id"] = image_id-》添加image_id(图像id)
        target["area"] = area-》添加area(范围)
        target["iscrowd"] = iscrowd-》添加iscrowd(这里取自VOC标注中的difficult标记)

        if self.transforms is not None:-》是否进行数据处理
            image, target = self.transforms(image, target)

        return image, target

def get_height_and_width(self, idx):-》获取图像的高度和宽度
        xml_path = self.xml_list[idx]-》读取xml路径
        with open(xml_path) as fid:-》打开xml
            xml_str = fid.read()-》读取xml
        xml = etree.fromstring(xml_str)-》将xml字符串解析为元素树
        data = self.parse_xml_to_dict(xml)["annotation"]-》获取xml中annotation这一节点值
        data_height = int(data["size"]["height"])-》获取xml中height这一节点值
        data_width = int(data["size"]["width"])-》获取xml中width这一节点值
        return data_height, data_width

 def parse_xml_to_dict(self, xml):-》xml转化为字典实现方法,将xml文件解析成字典形式,

        if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree


class VOC2012DataSet(Dataset):
    """Read and parse the PASCAL VOC2012 dataset.

    Each sample is an ``(image, target)`` pair. ``image`` is the PIL image
    opened from ``JPEGImages`` (before ``transforms``). ``target`` is a dict
    with the keys ``boxes``, ``labels``, ``image_id``, ``area`` and
    ``iscrowd``.
    """

    def __init__(self, voc_root, transforms, train_set=True):
        """Index the annotation files of one VOC2012 split.

        Args:
            voc_root: directory that contains ``VOCdevkit/VOC2012``.
            transforms: callable ``(image, target) -> (image, target)``
                applied in ``__getitem__``, or ``None`` to skip it.
            train_set: read ``ImageSets/Main/train.txt`` when True,
                otherwise ``ImageSets/Main/val.txt``.

        ``./pascal_voc_classes.json`` (relative to the current working
        directory) supplies the class-name -> label mapping. If it cannot be
        opened or parsed, the error is printed and the process exits with
        status -1.
        """
        self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt: one image id per line
        split_file = "train.txt" if train_set else "val.txt"
        txt_list = os.path.join(self.root, "ImageSets", "Main", split_file)
        with open(txt_list) as read:
            # Skip blank lines (e.g. a trailing newline) so they do not
            # become a bogus "<annotations_root>/.xml" entry.
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                             for line in read if line.strip()]

        # read class_indict; values are used as labels and cast to int64 later
        try:
            with open('./pascal_voc_classes.json', 'r') as json_file:
                self.class_dict = json.load(json_file)
        except (OSError, ValueError) as e:  # missing/unreadable file or invalid JSON
            import sys
            print(e)
            sys.exit(-1)

        self.transforms = transforms

    def __len__(self):
        """Return the number of annotation files in this split."""
        return len(self.xml_list)

    def _read_annotation(self, idx):
        """Parse the idx-th annotation XML and return its ``annotation`` dict."""
        xml_path = self.xml_list[idx]
        # Read raw bytes: lxml rejects str input that carries an XML encoding
        # declaration, but accepts bytes and honours the declared encoding.
        with open(xml_path, "rb") as fid:
            xml_bytes = fid.read()
        xml = etree.fromstring(xml_bytes)
        return self.parse_xml_to_dict(xml)["annotation"]

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        Raises:
            ValueError: if the referenced image file is not a JPEG.
        """
        data = self._read_annotation(idx)
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image format not JPEG")

        boxes = []
        labels = []
        iscrowd = []
        # parse_xml_to_dict only creates the "object" key when at least one
        # <object> element exists, so default to an empty list.
        for obj in data.get("object", []):
            bndbox = obj["bndbox"]
            xmin = float(bndbox["xmin"])
            xmax = float(bndbox["xmax"])
            ymin = float(bndbox["ymin"])
            ymax = float(bndbox["ymax"])
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            # The VOC "difficult" flag is reused as iscrowd; treat a missing
            # <difficult> tag as 0 instead of raising KeyError.
            iscrowd.append(int(obj.get("difficult", 0)))

        # convert everything into a torch.Tensor; reshape keeps boxes 2-D
        # (shape (0, 4)) when there are no objects so the slicing below works.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        # boxes are [xmin, ymin, xmax, ymax] -> area = (ymax - ymin) * (xmax - xmin)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def get_height_and_width(self, idx):
        """Return ``(height, width)`` of sample ``idx`` read from the XML
        ``size`` node, without opening the image file."""
        data = self._read_annotation(idx)
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    def parse_xml_to_dict(self, xml):
        """Recursively convert an XML element into nested dicts.

        Modelled on TensorFlow's ``recursive_parse_xml_to_dict``.

        Args:
            xml: XML element (e.g. from ``lxml.etree.fromstring``).

        Returns:
            ``{xml.tag: value}`` where ``value`` is the element text for a leaf
            element, otherwise a dict keyed by child tag. ``object`` children
            are always collected into a list, because an annotation may contain
            several of them. For other repeated tags the last one wins.
        """
        if len(xml) == 0:  # leaf element: return its text directly
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into children
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # there may be several objects, so keep a list
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}


# import transforms
# from draw_box_utils import draw_box
# from PIL import Image
# import json
# import matplotlib.pyplot as plt
# import torchvision.transforms as ts
# import random
#
# # read class_indict
# category_index = {}
# try:
#     json_file = open('./pascal_voc_classes.json', 'r')
#     class_dict = json.load(json_file)
#     category_index = {v: k for k, v in class_dict.items()}
# except Exception as e:
#     print(e)
#     exit(-1)
#
# data_transform = {
#     "train": transforms.Compose([transforms.ToTensor(),
#                                  transforms.RandomHorizontalFlip(0.5)]),
#     "val": transforms.Compose([transforms.ToTensor()])
# }
#
# # load train data set
# train_data_set = VOC2012DataSet(os.getcwd(), data_transform["train"], True)
# print(len(train_data_set))
# for index in random.sample(range(0, len(train_data_set)), k=5):
#     img, target = train_data_set[index]
#     img = ts.ToPILImage()(img)
#     draw_box(img,
#              target["boxes"].numpy(),
#              target["labels"].numpy(),
#              [1 for i in range(len(target["labels"].numpy()))],
#              category_index,
#              thresh=0.5,
#              line_thickness=5)
#     plt.imshow(img)
#     plt.show()

https://github.com/WZMIAOMIAO/deep-learning-for-image-processing

©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页