目标检测 (二):用于快速测试模型的小数据集--红绿灯数据集

默认加载以下模块:

%matplotlib inline
import os
import json
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import math
from IPython import display

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision
from torchvision import models
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
import visdom
# from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter

说明

在目标检测领域并没有类似MNIST或Fashion-MNIST那样的小数据集。为了快速测试模型,这里使用一个小数据集–红绿灯数据集。该数据集在img和json文件夹中分别存有图片与标注,文件名为0.jpg, 1.jpg… 0.json, 1.json…

标签的形状是(批量大小, m , 5),其中 m 等于数据集中单个图像最多含有的边界框个数。小批量计算虽然高效,但它要求每张图像含有相同数量的边界框,以便放在同一个批量中。由于每张图像含有的边界框个数可能不同,我们为边界框个数小于 m 的图像填充非法边界框,直到每张图像均含有 m 个边界框。这样,我们就可以每次读取小批量的图像了。图像中每个边界框的标签由长度为5的数组表示。数组中第一个元素是边界框所含目标的类别。当值为-1时,该边界框为填充用的非法边界框。数组的剩余4个元素分别表示边界框左上角的 x 和 y 轴坐标以及右下角的 x 和 y 轴坐标(值域在0到1之间)。这里的数据集中每个图像只有一个边界框,因此 m=1

读取数据集

class TrafficLightData(Dataset):
    """
    红绿灯数据集
    原始图像大小为240*424
    """
    def __init__(self, path, transforms=None):
        self.annotation_path = path + '/json'
        self.img_path = path + '/img'
        self.transforms = transforms
        
    def __len__(self):
        return len(os.listdir(self.img_path))
    
    def __getitem__(self, index):
        annotation = json.load(open(self.annotation_path + '/' + str(index) + '.json'))
        img = Image.open(self.img_path + '/' + str(index) + '.jpg')
        
        w, h = img.size
            
        # 红灯类别为0,绿灯类别为1
        cls = 0 if annotation['class'] == 'red' else 1
        
        label = np.array([cls] + annotation['loc'], dtype='float32')
        label[1:] /= np.array([w, h, w, h]) # bbox坐标转化成相对坐标形式
        
        label = torch.tensor(label)
        
        if self.transforms:
            img = self.transforms(img)
            
        sample = {
            "label": label, # shape: (1, 5) [class, xmin, ymin, xmax, ymax]
            "image": img    # shape: (3, *image_size)
        }
            
        return sample
  • 数据增广策略:对于训练集中的每张图像,采用随机裁剪,并要求裁剪出的图像至少覆盖每个目标95%的区域。由于裁剪是随机的,这个要求不一定总被满足。我们设定最多尝试200次随机裁剪:如果都不符合要求则不裁剪图像。为保证输出结果的确定性,我们不随机裁剪测试数据集中的图像。我们也无须按随机顺序读取测试数据集。
    由于目标检测中的数据增广还要改变标签,暂时等之后再实现
train_dataset = TrafficLightData('D:/Download/Dataset/traffic_light/train',
                transforms=transforms.Compose([
                    transforms.Resize(240), # 将图像最短边缩至240,宽高比例不变
                    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0),
                    transforms.ToTensor(), # 将PIL图像转为Tensor,并且进行归一化
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
                ]))
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)#, num_workers=2) # num_workers表示使用几个线程来加载数据 我的电脑加了这个参数就报错,可能不支持多线程操作

test_dataset = TrafficLightData('D:/Download/Dataset/traffic_light/dev',
                transforms=transforms.Compose([
                    transforms.Resize(240), # 将图像最短边缩至240,宽高比例不变
                    transforms.ToTensor(), # 将PIL图像转为Tensor,并且进行归一化
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
                ]))

batch_size = 4
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)#, num_workers=2) # num_workers表示使用几个线程来加载数据 我的电脑加了这个参数就报错,可能不支持多线程操作

可视化训练集图片

def use_svg_display():
    """Use svg format to display plot in jupyter"""
    display.set_matplotlib_formats('svg')

def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    # 设置图的尺寸
    plt.rcParams['figure.figsize'] = figsize
    
set_figsize()

def bbox_to_rect(bbox, color): 
    # 将边界框(左上x, 左上y, 右下x, 右下y)格式转换成matplotlib格式:
    # ((左上x, 左上y), 宽, 高)
    return plt.Rectangle(
        xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],
        fill=False, edgecolor=color, linewidth=2)

def show_bboxes(axes, bboxes, labels=None, colors=['b', 'g', 'r', 'm', 'c']):
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        rect = bbox_to_rect(bbox.detach().cpu().numpy(), color)
        axes.add_patch(rect) # 在原图片中画上锚框
        if labels and len(labels) > i:
            text_color = 'k' if color == 'w' else 'w'
            # 在锚框左上角标做标注
            axes.text(rect.xy[0], rect.xy[1], labels[i],
                      va='center', ha='center', fontsize=6, color=text_color,
                      bbox=dict(facecolor=color, lw=0))
data_iter = iter(train_dataloader)
sample = data_iter.next()
images = sample['image'] # (bn, c, h, w)
labels = sample['label'] # (bn, 5)
labels[0][1:].shape

# unnormalize  img*std+mean
means = torch.tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) # (1, c, 1, 1)
stds = torch.tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) # (1, c, 1, 1)
images = images * stds + means

# 可视化1个batch的训练集图像
num_cols = 2 # 一行4张图片
num_rows = math.ceil(batch_size // num_cols)
scale = 5 # 图片大小
figsize = (num_cols * scale, num_rows * scale)
fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize) 
axes = axes.flatten()
h, w = images.shape[-2:]
bbox_scale = torch.Tensor([w, h, w, h]) # 因为锚框坐标值是0`1,因此这里乘上宽高来恢复其原来的坐标位置
for i, image in enumerate(images):
    axes[i].imshow(image.permute(1, 2, 0)) # imshow要求的图片形状是 (h, w, c)
    show_bboxes(axes[i], [labels[i][1:] * bbox_scale])
    axes[i].axes.get_xaxis().set_visible(False) # 不显示x,y的坐标值
    axes[i].axes.get_yaxis().set_visible(False)
# tensorboard可视化
# writer = SummaryWriter('./runs/traffic_lights')
# writer.add_figure('Train_imgs', fig, 0)
# writer.close()
# 打印选中训练集图像的类别
print(sample['label'][:, 0]) 

output:

tensor([0., 1., 0., 1.])

在这里插入图片描述

参考文献

  • 4
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值