李沐动手学深度学习V2-目标检测数据集

一.目标检测数据集

1. 数据集介绍

目标检测领域没有像MNIST和Fashion-MNIST那样的小数据集,为了快速测试目标检测模型,收集并标记了一个小型数据集。 首先拍摄了一组香蕉的照片,并生成了1000张不同角度和大小的香蕉图像,然后在一些背景图片的随机位置上放一张香蕉的图像,最后在图片上为这些香蕉标记了边界框。

2. 下载数据集并定义数据集

该数据集包括一个的CSV文件,内含目标类别标签和位于左上角和右下角的真实边界框坐标。

import torch
import os
import pandas as pd
import torchvision
import d2l.torch
#@save
d2l.torch.DATA_HUB['banana-detection'] = (
    d2l.torch.DATA_URL + 'banana-detection.zip',
    '5de26c8fce5ccdea9f91267273464dc968d20d72')
def read_data_bananas(is_train=True):
    data_dir = d2l.torch.download_extract('banana-detection')
    print(data_dir)
    csv_fpath = os.path.join(data_dir,'bananas_train'if is_train else 'bananas_val','label.csv')
    csv_file = pd.read_csv(csv_fpath)
    csv_file = csv_file.set_index('img_name')
    images,targets=[],[]
    for image_name,target in csv_file.iterrows():
        images.append(torchvision.io.read_image(path=os.path.join(data_dir,'bananas_train'if is_train else 'bananas_val','images',f'{image_name}')))
        targets.append(list(target))
    return images,torch.tensor(targets).unsqueeze(1)/256
images,targets = read_data_bananas()
targets[0]
targets.shape

通过使用read_data_bananas()函数读取图像和标签,以下BananasDataset类别将允许创建一个自定义Dataset实例来加载香蕉检测数据集。

"""一个用于加载香蕉检测数据集的自定义数据集"""
class BananasDataset(torch.utils.data.Dataset):
    def __init__(self,is_train):
        self.features,self.labels = read_data_bananas(is_train)
    def __getitem__(self, item):
        return (self.features[item].float(),self.labels[item])
    def __len__(self):
        return len(self.features)

定义load_data_bananas()函数,为训练集和测试集返回两个数据加载器实例,对于测试集,无须按随机顺序读取它。

"""加载香蕉检测数据集"""
def load_data_bananas(batch_size):
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),batch_size,shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),batch_size,shuffle=False)
    return train_iter,val_iter

3. 加载数据集

图像的小批量的形状为(批量大小、通道数、高度、宽度) ,标签的小批量的形状为(批量大小, 𝑚 ,5),其中 𝑚 是数据集的任何图像中边界框可能出现的最大数量
小批量计算虽然高效,但它要求每张图像含有相同数量的边界框,以便放在同一个批量中。 通常来说,图像可能拥有不同数量个边界框;因此,在达到 𝑚 之前,边界框少于 𝑚 的图像将被非法边界框填充。 这样,每个边界框的标签将被长度为5的数组表示。 数组中的第一个元素是边界框中对象的类别,其中-1表示用于填充的非法边界框。 数组的其余四个元素是边界框左上角和右下角的( 𝑥 , 𝑦 )坐标值(值域在0到1之间)。 对于香蕉数据集而言,由于每张图像上只有一个边界框,因此 𝑚=1 。

batch_size,edge_size = 32,256
train_iter,valid_iter = load_data_bananas(batch_size)
batch = next(iter(train_iter))
batch[0].shape,batch[1].shape

4. 数据集图片展示

展示10幅带有真实边界框的图像,可以看到在所有这些图像中香蕉的旋转角度、大小和位置都有所不同。 当然这只是一个简单的人工数据集,实践中真实世界的数据集通常要复杂得多。

#permute()函数可以同时多次交换tensor的维度,如:b = a.permute(0,2 ,1) 将a的维度索引1和维度索引2交换位置
imgs = (batch[0][0:10].permute(0,2,3,1))/255 #除以255是为了对图片中每一个像素进行标准化
axes = d2l.torch.show_images(imgs,2,5,scale=2)
d2l.torch.bbox_to_rect()
for axe,label in zip(axes,batch[1][0:10]):
    d2l.torch.show_bboxes(axe,bboxes=[label[0][1:5]*256],colors=['w']) #label[0][1:5]*256乘256是因为加载数据集时bounding box边缘框除以256,256是图片的高和宽

图片展示

5.小结

  1. 收集的香蕉检测数据集可用于演示目标检测模型。
  2. 用于目标检测的数据加载与图像分类的数据加载类似。但是在目标检测中,标签还包含真实边界框的信息,在图像分类中没有真实边界框标签信息。

6. 全部代码

import torch
import os
import pandas as pd
import torchvision
import d2l.torch

#@save
d2l.torch.DATA_HUB['banana-detection'] = (
    d2l.torch.DATA_URL + 'banana-detection.zip',
    '5de26c8fce5ccdea9f91267273464dc968d20d72')


def read_data_bananas(is_train=True):
    data_dir = d2l.torch.download_extract('banana-detection')
    print(data_dir)
    csv_fpath = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv')
    csv_file = pd.read_csv(csv_fpath)
    csv_file = csv_file.set_index('img_name')
    images, targets = [], []
    for image_name, target in csv_file.iterrows():
        images.append(torchvision.io.read_image(
            path=os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'images', f'{image_name}')))
        targets.append(list(target))
    return images, torch.tensor(targets).unsqueeze(1) / 256


images, targets = read_data_bananas()
targets[0]
targets.shape
"""一个用于加载香蕉检测数据集的自定义数据集"""


class BananasDataset(torch.utils.data.Dataset):
    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)

    def __getitem__(self, item):
        return (self.features[item].float(), self.labels[item])

    def __len__(self):
        return len(self.features)


"""加载香蕉检测数据集"""


def load_data_bananas(batch_size):
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True), batch_size, shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False), batch_size, shuffle=False)
    return train_iter, val_iter


batch_size, edge_size = 32, 256
train_iter, valid_iter = load_data_bananas(batch_size)
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape
#permute()函数可以同时多次交换tensor的维度,如:b = a.permute(0,2 ,1) 将a的维度索引1和维度索引2交换位置
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255  #除以255是为了对图片中每一个像素进行标准化
axes = d2l.torch.show_images(imgs, 2, 5, scale=2)
d2l.torch.bbox_to_rect()
for axe, label in zip(axes, batch[1][0:10]):
    d2l.torch.show_bboxes(axe, bboxes=[label[0][1:5] * 256],
                          colors=['w'])  #label[0][1:5]*256乘256是因为加载数据集时bounding box边缘框除以256,256是图片的高和宽
x_range = torch.arange(0, 2)
y_range = torch.arange(0, 4)
x, y = torch.meshgrid(y_range, x_range)
print(x)
print(y)
x, y = torch.meshgrid(x_range, y_range)
print(x)
print(y)
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值