【无标题】

PyTorch 官方实现 Faster R-CNN 的步骤详解（一）——数据制作

# 读取解析PASCAL VOC2012 数据集
class VOC2012DataSet(Dataset):
    """Read and parse the PASCAL VOC2012 detection dataset.

    Each item is an ``(image, target)`` pair: ``image`` is the object returned
    by ``Image.open`` (optionally passed through ``transforms``) and ``target``
    is a dict with the keys ``boxes``, ``labels``, ``image_id``, ``area`` and
    ``iscrowd``.
    """

    def __init__(self, voc_root=r'F:\data_set\VOCtrainval_11-May-2012', transforms=None, train_set=True,
                 json_file=r'D:\Python\text\python-learn\FasterRCNN\pascal_voc_classes.json'):
        """
        :param voc_root: directory that contains ``VOCdevkit``; files are read
            from ``<voc_root>/VOCdevkit/VOC2012``.
        :param transforms: optional callable ``(image, target) -> (image, target)``.
            Unlike classification transforms, a detection transform such as a
            horizontal flip must also transform ``target['boxes']``.
        :param train_set: read ``ImageSets/Main/train.txt`` when True,
            otherwise ``ImageSets/Main/val.txt``.
        :param json_file: JSON file mapping class name -> integer label, e.g.
            ``{"aeroplane": 1, ..., "tvmonitor": 20}``. Defaults to the path
            that used to be hard-coded.
        """
        # <voc_root>/VOCdevkit/VOC2012 and its image / annotation folders
        self.root = os.path.join(voc_root, 'VOCdevkit', 'VOC2012')
        self.img_root = os.path.join(self.root, 'JPEGImages')
        self.annotations_root = os.path.join(self.root, 'Annotations')

        # Each line of train.txt / val.txt is an image id such as "2008_000008";
        # it is turned into <annotations_root>/2008_000008.xml.
        split_name = 'train.txt' if train_set else 'val.txt'
        txt_list = os.path.join(self.root, 'ImageSets', 'Main', split_name)
        self.xml_list = []
        with open(txt_list) as read:
            for line in read:
                image_name = line.strip()  # drop the trailing newline
                if not image_name:
                    # A blank line would otherwise become "Annotations/.xml".
                    continue
                self.xml_list.append(os.path.join(self.annotations_root, image_name + '.xml'))

        # Class name -> label mapping; the context manager closes the file.
        with open(json_file, 'r') as class_file:
            self.class_dict = json.load(class_file)
        self.transforms = transforms

    def __len__(self):
        return len(self.xml_list)

    def _read_annotation(self, idx):
        """Read and parse the XML of sample ``idx``; return its ``annotation`` dict."""
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        return self.parse_xml_to_dict(xml)['annotation']

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        :raises ValueError: if the referenced image is not a JPEG.
        :raises KeyError: if an object's class name is not in ``class_dict``.
        """
        # Example of the parsed annotation for 2008_000008.xml (every leaf is a str):
        # {'folder': 'VOC2012', 'filename': '2008_000008.jpg',
        #  'size': {'width': '500', 'height': '442', 'depth': '3'},
        #  'object': [{'name': 'horse', 'difficult': '0',
        #              'bndbox': {'xmin': '53', 'ymin': '87', 'xmax': '471', 'ymax': '420'}, ...},
        #             {'name': 'person', ...}], ...}
        data = self._read_annotation(idx)
        img_path = os.path.join(self.img_root, str(data['filename']))
        image = Image.open(img_path)
        if image.format != 'JPEG':
            raise ValueError('Image format not JPEG')

        boxes = []
        labels = []
        iscrowd = []
        # parse_xml_to_dict always stores <object> entries in a list; an
        # annotation without any <object> simply has no 'object' key.
        for obj in data.get('object', []):
            bbox = obj['bndbox']
            boxes.append([float(bbox['xmin']), float(bbox['ymin']),
                          float(bbox['xmax']), float(bbox['ymax'])])
            labels.append(self.class_dict[obj['name']])
            # The VOC "difficult" flag is stored as "iscrowd"; an object may
            # lack a <difficult> tag, in which case it is treated as 0.
            iscrowd.append(int(obj.get('difficult', 0)))

        # reshape keeps boxes 2-D, (N, 4), even when N == 0 so that the column
        # slicing used for `area` below still works.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        # area = (ymax - ymin) * (xmax - xmin) for every box
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd,
        }

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def parse_xml_to_dict(self, xml):
        """Recursively convert an XML element into nested dicts.

        A leaf element (no children) maps ``tag -> text``; any other element
        maps ``tag -> {child_tag: child_value, ...}``. ``<object>`` children
        are collected into a list because an annotation can contain several of
        them; for any other repeated tag only the last occurrence is kept.

        :param xml: parsed element, e.g. the root ``<annotation>`` element
            returned by ``etree.fromstring``.
        :return: ``{xml.tag: value}``.
        """
        if len(xml) == 0:  # leaf: return its text directly
            return {xml.tag: xml.text}
        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into the child
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                # several <object> elements may exist, so always keep a list
                result.setdefault(child.tag, []).append(child_result[child.tag])
        return {xml.tag: result}

    def get_height_and_width(self, idx):
        """Return ``(height, width)`` of sample ``idx`` as read from its XML <size>."""
        data = self._read_annotation(idx)
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    @staticmethod
    def collate_fn(batch):
        """Regroup ``[(img0, tgt0), (img1, tgt1), ...]`` into ``((img0, img1, ...), (tgt0, tgt1, ...))``.

        No stacking is performed (presumably because detection images can have
        different sizes).
        """
        return tuple(zip(*batch))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值