使用 PyTorch 处理CUB200_2011数据集

前言

最近在研究一些深度学习的东西,想找一些数据集,第一眼看到 CUB 的时候,就相中它了,数据集大小合适,类别还很多,关键还提供了目标检测的边框,不仅可以做分类,还可以做检测(虽然只是单目标检测~),感觉很酷,想研究一下。

本文主要参考 CUB200_2011数据集处理,感觉这篇文章文字虽然少,但代码写得却很好,大赞

数据集介绍

CUB 数据集一共 200 个类别,共 11788 张图片,每张图片除包括类别标签外,还有一个标注的物体边框(Bounding Box)、关键点和一些其他属性,算是一个很细粒度的图像分类的数据集了。

CUB 数据集有 2010 版和 2011 版,在 官网 界面可以选择,本文处理的是 2011 版的。

这是数据集的下载地址:CUB_200_2011.tgz

下载后,需要解压两次,解压后的文件夹内部如图所示:

CUB_200_2011数据集文件夹
其中,不考虑属性信息,主要有五个说明文档

  • bounding_boxes.txt : 包含每张图像的物体边框,格式为 <image_id> <x> <y> <width> <height>

  • classes.txt : 包含每张图片的类别序号和名称,格式为 <class_id> <class_name>

  • image_class_labels.txt : 包含每张图片对应的类别序号,格式为 <image_id> <class_id>

  • images.txt : 包含每张图片的路径信息,格式为 <image_id> <image_name>

  • train_test_split.txt : 记录数据集的训练集和测试集划分,格式为 <image_id> <is_training_image>

代码

使用 pytorch 制作自定义数据集类别时,需要继承 Dataset 类,实现 __getitem__ 方法和 __len__ 方法。

代码主要想是放到 AlexNet 做训练,但是 AlexNet 输入为 224 × 224 224\times224 224×224 或是 227 × 227 227\times227 227×227,而数据集里的图片基本都比这个格式大,因此,我选择对每张图片提取 bounding_boxes 边框,并对边框里的图像进行 resize 操作(直接resize会留有较大部分背景区域)

步骤是分别读取四个配置文件(classes.txt没有用到),然后读取图片,并对每张图片做处理(有的图片是单通道的,使用 convert 转成 ‘RGB’ 格式)。

全部代码如下:

from torch.utils.data import Dataset
import os
from PIL import Image

class CUB200(Dataset):
    def __init__(self, root, image_size=227, train=True, transform=None, target_transform=None):
        '''
        从文件中读取图像,数据
        '''
        self.root = root  # 数据集路径
        self.image_size = image_size  # 图像大小(正方形)
        self.transform = transform  # 图像的 transform 
        self.target_transform = target_transform  # 标签的 transform 

        # 构造数据集参数的各文件路径
        self.classes_file = os.path.join(root, 'classes.txt')  # <class_id> <class_name>
        self.image_class_labels_file = os.path.join(root, 'image_class_labels.txt')  # <image_id> <class_id>
        self.images_file = os.path.join(root, 'images.txt')  # <image_id> <image_name>
        self.train_test_split_file = os.path.join(root, 'train_test_split.txt')  # <image_id> <is_training_image>
        self.bounding_boxes_file = os.path.join(root, 'bounding_boxes.txt')  # <image_id> <x> <y> <width> <height>

        imgs_name_train, imgs_name_test, imgs_label_train, imgs_label_test, imgs_bbox_train, imgs_bbox_test = self._get_img_attributes()
        if train: # 读取训练集
            self.data = self._get_imgs(imgs_name_train, imgs_bbox_train)
            self.label = imgs_label_train
        else: # 读取测试集
            self.data = self._get_imgs(imgs_name_test, imgs_bbox_test)
            self.label = imgs_label_test

    def _get_img_id(self):
        ''' 读取张图片的 id,并根据 id 划分为测试集和训练集 '''
        imgs_id_train, imgs_id_test = [], []
        file = open(self.train_test_split_file, "r")
        for line in file:
            img_id, is_train = line.split()
            if is_train == "1":
                imgs_id_train.append(img_id)
            elif is_train == "0":
                imgs_id_test.append(img_id)
        file.close()
        return imgs_id_train, imgs_id_test

    def _get_img_class(self):
        ''' 读取每张图片的 class 类别 '''
        imgs_class = []
        file = open(self.image_class_labels_file, 'r')
        for line in file:
            _, img_class = line.split()
            imgs_class.append(img_class)
        file.close()
        return imgs_class

    def _get_bondingbox(self):
        ''' 获取图像边框 '''
        bondingbox = []
        file = open(self.bounding_boxes_file)
        for line in file:
            _, x, y, w, h = line.split()
            x, y, w, h = float(x), float(y), float(w), float(h)
            bondingbox.append((x, y, x+w, y+h))
        file.close()
        return bondingbox

    def _get_img_attributes(self):
        ''' 根据图片 id 读取每张图片的属性,包括名字(路径)、类别和边框,并分别按照训练集和测试集划分 '''
        imgs_name_train, imgs_name_test, imgs_label_train, imgs_label_test, imgs_bbox_train, imgs_bbox_test = [], [], [], [], [], []
        imgs_id_train, imgs_id_test = self._get_img_id()  # 获取训练集和测试集的 img_id
        imgs_bbox = self._get_bondingbox()  # 获取所有图像的 bondingbox
        imgs_class = self._get_img_class()  # 获取所有图像类别标签,按照 img_id 存储
        file = open(self.images_file)
        for line in file:
            img_id, img_name = line.split()
            if img_id in imgs_id_train:
                img_id = int(img_id)
                imgs_name_train.append(img_name)
                imgs_label_train.append(imgs_class[img_id-1]) # 下标从 0 开始
                imgs_bbox_train.append(imgs_bbox[img_id-1])
            elif img_id in imgs_id_test:
                img_id = int(img_id)
                imgs_name_test.append(img_name)
                imgs_label_test.append(imgs_class[img_id-1])
                imgs_bbox_test.append(imgs_bbox[img_id-1])
        file.close()
        return imgs_name_train, imgs_name_test, imgs_label_train, imgs_label_test, imgs_bbox_train, imgs_bbox_test

    def _get_imgs(self, imgs_name, imgs_bbox):
        ''' 遍历每一张图片的路径,读取图片信息 '''
        data = []
        for i in range(len(imgs_name)):
            img_path = os.path.join(self.root, 'images', imgs_name[i])
            img = self._convert_and_resize(img_path, imgs_bbox[i])
            data.append(img)
        return data

    def _convert_and_resize(self, img_path, img_bbox):
        ''' 将不是 'RGB' 模式的图像变为 'RGB' 格式,更改图像大小 '''
        img = Image.open(img_path).resize((self.image_size, self.image_size), box=img_bbox)
        img.show()
        if img.mode == 'L':
            img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        print(type(img))
        return img

    def __getitem__(self, index):
        img, label = self.data[index], self.label[index]
        label = int(label) - 1  # 类别从 0 开始
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.data)

if __name__ == "__main__":
    train_set = CUB200("./CUB_200_2011", train=True)  # 共 5994 张图片
    test_set = CUB200("./CUB_200_2011", train=False)  # 共 5794 张图片

最后

我还并没有使用该数据集进行训练网络,所以不知道是否有其他 bug。自己想实现的算法对数据集还需要进一步处理,这里先记录一下自己目前处理的成果。

2020/05/21 更新

这个写了很久了,本来想用这个数据集做单目标检测,但是最后还是没用到。主要是这个数据集的目标框的标注十分不准确,可能本身就是用来做分类的数据集,强行做目标检测是行不通的。(不得不吐槽一下,框标注的这么不准确,为啥还要有标注框信息啊。。。)

代码方面,由于我使用了标注框,对标注框内的图像做了截断然后 resize 了。如果标注框是准确的话,这没啥大问题,万万没想到标注框不准,所以代码就基本不能用了。我还是太年轻,太相信数据集的可靠性了。。。

当时搞了好几天最后发现不行,一生气,把电脑里关于这个数据集相关的都删掉了,。易怒症/xjj

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值