深度学习 —— 个人学习笔记16(目标检测和边界框、目标检测数据集)

声明

  本文章为个人学习使用,版面观感若有不适请谅解,文中知识仅代表个人观点,若出现错误,欢迎各位批评指正。

三十二、目标检测和边界框

import torch
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline

def show_images(imgs, titles=None):
        plt.imshow(imgs)
        backend_inline.set_matplotlib_formats('svg')
        plt.rcParams['figure.figsize'] = (6.5, 3.5)
        plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
        plt.title(titles)
        plt.show()


img = plt.imread('E:\\cat\\catdog.jpg')
show_images(img, titles='原图')

def box_corner_to_center(boxes):
    """从(左上,右下)转换到(中间,宽度,高度)"""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    boxes = torch.stack((cx, cy, w, h), axis=-1)
    return boxes

def box_center_to_corner(boxes):
    """从(中间,宽度,高度)转换到(左上,右下)"""
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    boxes = torch.stack((x1, y1, x2, y2), axis=-1)
    return boxes


# bbox是边界框的英文缩写
dog1_bbox, dog2_bbox = [55.0, 265.0, 252.0, 590.0], [400.0, 18.0, 656.0, 590.0]
cat1_bbox, cat2_bbox = [231.0, 188.0, 443.0, 595.0], [650.0, 226.0, 905.0, 590.0]
boxes = torch.tensor((dog1_bbox, dog2_bbox, cat1_bbox, cat2_bbox))

print('测试函数正确性 : ', box_center_to_corner(box_corner_to_center(boxes)) == boxes)

def bbox_to_rect(bbox, color):
    # 将边界框(左上x,左上y,右下x,右下y)格式转换成 matplotlib 格式:
    # ((左上x,左上y),宽,高)
    return plt.Rectangle(
        xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],
        fill=False, edgecolor=color, linewidth=2)


fig = plt.imshow(img)
fig.axes.add_patch(bbox_to_rect(cat1_bbox, 'red'))
fig.axes.add_patch(bbox_to_rect(cat2_bbox, 'red'))
fig.axes.add_patch(bbox_to_rect(dog2_bbox, 'blue'))
fig.axes.add_patch(bbox_to_rect(dog1_bbox, 'blue'))

plt.axis('off')
plt.suptitle('标记后')
plt.show()


三十三、目标检测数据集

import os
import pandas as pd
import torch
import torchvision
from matplotlib import pyplot as plt

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    numpy = lambda x, *args, **kwargs: x.detach().numpy(*args, **kwargs)
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        try:
            img = numpy(img)
        except:
            pass
        ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

def bbox_to_rect(bbox, color):
    return plt.Rectangle(
        xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],
        fill=False, edgecolor=color, linewidth=2)

def show_bboxes(axes, bboxes, labels=None, colors=None):
    def make_list(obj, default_values=None):
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj

    numpy = lambda x, *args, **kwargs: x.detach().numpy(*args, **kwargs)
    labels = make_list(labels)
    colors = make_list(colors, ['b', 'g', 'r', 'm', 'c'])
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        rect = bbox_to_rect(numpy(bbox), color)
        axes.add_patch(rect)
        if labels and len(labels) > i:
            text_color = 'k' if color == 'w' else 'w'
            axes.text(rect.xy[0], rect.xy[1], labels[i],
                      va='center', ha='center', fontsize=9, color=text_color,
                      bbox=dict(facecolor=color, lw=0))

def read_data_bananas(is_train=True):
    """读取香蕉检测数据集中的图像和标签"""
    data_dir = 'E:\\banana-detection'
    csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
                             else 'bananas_val', 'label.csv')
    csv_data = pd.read_csv(csv_fname)
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []
    for img_name, target in csv_data.iterrows():
        images.append(torchvision.io.read_image(
            os.path.join(data_dir, 'bananas_train' if is_train else
                         'bananas_val', 'images', f'{img_name}')))
        # 这里的 target 包含(类别,左上角 x,左上角 y,右下角 x,右下角 y),
        # 其中所有图像都具有相同的香蕉类(索引为0)
        targets.append(list(target))
    return images, torch.tensor(targets).unsqueeze(1) / 256

class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉检测数据集的自定义数据集"""
    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (f' training examples' if
              is_train else f' validation examples'))

    def __getitem__(self, idx):
        return (self.features[idx].float(), self.labels[idx])

    def __len__(self):
        return len(self.features)

def load_data_bananas(batch_size):
    """ 加载香蕉检测数据集 """
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                                             batch_size, shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                                           batch_size)
    return train_iter, val_iter


batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
print(f'(批量大小、通道数、高度、宽度) : {batch[0].shape}\n'
      f'(批量大小、数据集的任何图像中边界框可能出现的最大数量、5) : {batch[1].shape}')

imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
    show_bboxes(ax, [label[0][1:5] * edge_size], colors=['r'])

plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.suptitle('数据集展示')
plt.show()



  文中部分知识参考:B 站 —— 跟李沐学AI;百度百科

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值