2021-07-22

使用 Transformer 进行目标检测(含代码)

用于对象检测的 Facebook 检测转换器 (DETR) 的完整指南。

英文版原文转载自https://medium.com/swlh/object-detection-with-transformers-437217a3d62e
在这里插入图片描述 DETR 的示例输出

介绍

DEtection TRansformer (DETR) 是 Facebook 研究团队开发的一种对象检测模型,它巧妙地利用了 Transformer 架构。

在这篇文章中,我将介绍 DETR 架构的内部工作原理,以帮助对其活动部件提供一些直观的了解。可以在此处找到本教程附带的 colab 笔记本

下面,我将解释一些架构,但如果您只想了解如何使用模型,请随时跳到编码部分。

架构

DETR 模型由一个预训练的CNN 主干(如 ResNet)组成,它产生一组低维特征集。

将这些特征格式化为一组特征并添加到位置编码中,然后将其馈入由编码器和解码器组成的变换器,其方式与原始变换器论文中描述的编码器-解码器变换器非常相似(http://arxiv.org/abs/1706.03762)。

接着将解码器的输出馈入固定数量的预测头它由预定义数量的前馈网络组成。这些预测头之一的每个输出都包含一个类预测,以及一个预测的边界框。

损失是通过计算二分匹配损失来计算的。
在这里插入图片描述DETR 架构;来自https://arxiv.org/pdf/2005.12872v3.pdf

该模型进行预定义数量的预测,并且每个预测都是并行计算的。

CNN主干

假设我们的高度的输入图像xᵢₘ 高度为H 0,宽度为W 0,并且使用三个输入通道。CNN 主干由一个预训练的CNN(通常是 ResNet)组成,我们用它来生成C 个具有宽度 W 和高度 H 的低维特征(实际上,我们设置C =2048,W=W₀/32 和 H=H₀/32 )。

这给我们留下了 C 个二维特征,并且由于我们将这些特征传递到transformer中,每个特征必须以允许编码器将每个特征作为序列处理的方式重新格式化。这是通过将特征矩阵展平为 H⋅W 向量,然后连接每个向量来完成的。
在这里插入图片描述卷积层的输出 → 图像特征

扁平化的卷积特征被添加到空间位置编码中,可以学习或预定义。

The Transformer

Transformer几乎与原始编码器-解码器架构相同。不同之处在于每个解码器层并行解码 N 个预定义数量的对象中的每一个。

该模型还学习了一组 N 个对象查询,这些查询是类似于编码器学习的位置编码。
在这里插入图片描述

对象查询

下图描绘了 N=20 的时候学习对象查询(称为预测槽)如何关注图像的不同区域。
在这里插入图片描述

“我们观察到,每个插槽都学会了专注于具有多种操作模式的特定区域和盒子尺寸。” — DETR 作者

理解对象查询的一种直观方式是想象每个对象查询都是一个人。每个人都可以通过注意力询问图像的某个区域。因此,一个对象查询将始终询问图像中心是什么,而另一个将始终询问左下角是什么,依此类推。

使用 PyTorch 实现简单的 DETR

import torch
import torch.nn as nn
from torchvision.models import resnet50

class SimpleDETR(nn.Module):
"""
Minimal Example of the Detection Transformer model with learned positional embedding
"""
 def __init__(self, num_classes, hidden_dim, num_heads,
             num_enc_layers, num_dec_layers):
    super(SimpleDETR, self).__init__()
    self.num_classes = num_classes
    self.hidden_dim = hidden_dim
    self.num_heads = num_heads
    self.num_enc_layers = num_enc_layers
    self.num_dec_layers = num_dec_layers
    # CNN Backbone
    self.backbone = nn.Sequential(
         *list(resnet50(pretrained=True).children())[:-2])
    self.conv = nn.Conv2d(2048, hidden_dim, 1)
    # Transformer
    self.transformer = nn.Transformer(hidden_dim, num_heads,
         num_enc_layers, num_dec_layers)
    # Prediction Heads
    self.to_classes = nn.Linear(hidden_dim, num_classes+1)
    self.to_bbox = nn.Linear(hidden_dim, 4)
    # Positional Encodings
    self.object_query = nn.Parameter(torch.rand(100, hidden_dim))
    self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)
    self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
 def forward(self, X):
    X = self.backbone(X)
    h = self.conv(X)
    H, W = h.shape[-2:]
    pos_enc = torch.cat([
          self.col_embed[:W].unsqueeze(0).repeat(H,1,1),
          self.row_embed[:H].unsqueeze(1).repeat(1,W,1)],
       dim=-1).flatten(0,1).unsqueeze(1)
    h = self.transformer(pos_enc + h.flatten(2).permute(2,0,1),
    self.object_query.unsqueeze(1))
    class_pred = self.to_classes(h)
    bbox_pred = self.to_bbox(h).sigmoid()
    
    return class_pred, bbox_pred

二分匹配损失(可选)

令ŷ={ŷ ᵢ | i=1,…N}是一组预测,其中 ŷ=(ĉᵢ, bᵢ) 是由预测类(可以是空类)和边界框 bᵢ=( x̄ ᵢ , y̅ ᵢ , wᵢ ) 组成的元组, hᵢ ),其中条形符号表示端点之间的中点,wᵢ 和 hᵢ 分别是框的宽度和高度。

让 y 表示地面实况集。假设 y 和ŷ之间的损失为 L,每个 yᵢ 和ŷ ᵢ之间的损失为 L ᵢ。由于我们在集合级别上工作,因此损失 L 必须是置换不变的,这意味着无论我们如何对预测进行排序,我们都将获得相同的损失。因此,我们希望找到一个排列σ∈ Sₙ它映射预测的地面实况目标的指数的指数。在数学上,我们正在求解
最优二分匹配
最优二分匹配

计算 σ_hat 的过程称为寻找最优二分匹配。这可以使用匈牙利算法找到。但是为了找到最佳匹配,我们实际上需要定义一个损失函数来计算yᵢ 和ŷ _σ(i)之间的匹配成本。

回想一下,我们的预测由一个边界框和一个类组成。现在让我们假设类预测实际上是类集上的概率分布。那么第i个预测的总损失将是类预测产生的损失和边界框预测产生的损失。作者http://arxiv.org/abs/1906.05909将这种损失定义为边界框损失和类别预测概率的差异:
匹配损失匹配损失

其中 p-hatᵢ(cᵢ) 是来自 cᵢ 的 logits 的 argmax,Lbox 是边界框预测产生的损失。上面还指出,如果 cᵢ=∅,匹配损失为 0。

框损失计算为 L₁ 损失(位移)和预测边界框与地面实况边界框之间的广义交叉联合(GIOU) 损失的线性组合。

此外,如果你想象两个不相交的边界框,那么框错误将不会提供任何有意义的上下文(正如我们从下面框损失的定义中看到的那样)。
盒子损失
盒子损失

在上面的方程中,参数 λᵢₒᵤ 和 λ_L₁ 是标量超参数。请注意,这个总和也是由面积和距离产生的误差的组合。为什么这是有道理的?

是有意义的认为方程的如上面的总成本(i)与所述预测的B-hat_σ相关联,其中该价格区的错误是λᵢₒᵤ和价格距离误差是λ_L₁

现在让我们实际定义 GIOU 损失函数。它的定义如下:
GIOU 损失
GIOU 损失

由于我们是从给定数量的已知类别中预测类别,因此类别预测是一个分类问题,因此我们可以使用交叉熵损失来计算类别预测误差。我们将匈牙利损失函数定义为每 N 个预测损失的总和:
匈牙利损失函数
匈牙利损失函数

使用 DETR 进行对象检测

在这里,您可以了解如何使用 PyTorch 加载预训练的 DETR 模型以进行对象检测。

加载模型
首先导入将要使用的所需模块。

# Import required modules
import torch
from torchvision import transforms as T 
import requests # for loading images from web
from PIL import Image # for viewing images
import matplotlib.pyplot as plt

以下代码使用 ResNet50 作为 CNN 主干从 Torch Hub 加载预训练模型。对于其他主干,请参阅DETR github

detr = torch.hub.load('facebookresearch/detr',
                      'detr_resnet50',
                       pretrained=True)

加载图像
要从 Web 加载图像,我们使用 requests 库:

url = 'https://www.tempetourism.com/wp-content/uploads/Postino-Downtown-Tempe-2.jpg' # Sample image
image = Image.open(requests.get(url, stream=True).raw) 
plt.imshow(image)
plt.show()

在这里插入图片描述

设置对象检测管道

要将图像输入模型,我们需要将图像从 PIL Image 转换为张量,这是通过使用 torchvision 的转换库来完成的。

transform = T.Compose([T.Resize(800),
                       T.ToTensor(),
                       T.Normalize([0.485, 0.456, 0.406],
                                  [0.229, 0.224, 0.225])])

上述变换调整图像大小,将图像从 PIL 图像转换,并使用均值标准差对图像进行归一化。其中 [0.485, 0.456, 0.406] 是每个颜色通道的平均值,[0.229, 0.224, 0.225] 是每个颜色通道的标准偏差。要查看更多图像转换,请参阅torchvision 文档

我们加载的模型是在COCO 数据集上预先训练的,包含 91 个类以及一个表示空类(无对象)的附加类。我们使用以下代码手动定义每个标签:

CLASSES = 
['N/A', 'Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic-Light', 'Fire-Hydrant', 'N/A', 'Stop-Sign', 'Parking Meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'N/A', 'Backpack', 'Umbrella', 'N/A', 'N/A', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports-Ball', 'Kite', 'Baseball Bat', 'Baseball Glove', 'Skateboard', 'Surfboard', 'Tennis Racket', 'Bottle', 'N/A', 'Wine Glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot-Dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted Plant', 'Bed', 'N/A', 'Dining Table', 'N/A','N/A', 'Toilet', 'N/A', 'TV', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell-Phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'N/A', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy-Bear', 'Hair-Dryer', 'Toothbrush']

如果我们想输出不同颜色的边界框,我们可以手动定义我们想要的RGB格式的颜色

COLORS = [ 
    [0.0000.4470.741][0.8500.3250.098][0.9290.6940.125][0.4940.1840.556][0.4660.6740.188][0.3010.7450.933]   
]

格式化输出
我们还需要重新格式化模型的输出。给定转换后的图像,模型将输出一个字典,其中包含 100 个预测类别的概率和 100 个预测边界框。
每个边界框都有这样的形式(x, y, w, h),其中 (x,y) 是边界框的中心(其中框是单位正方形 [0,1] ×[0,1]),w、h 是边界框的宽度和高度边界框。所以我们需要将边界框输出转换为初始和最终坐标,并重新缩放框以适应我们图像的实际大小。
以下函数返回边界框端点:

# 从模型输出 (x, y, w, h) 获取坐标 (x0, y0, x1, y0)
def get_box_coords ( boxes ): 
    x, y, w, h = box.unbind(1) 
    x0, y0 = (x - 0.5 * w), (y - 0.5 * h) 
    x1, y1 = (x + 0.5 * w) , (y + 0.5 * h) 
    box = [x0, y0, x1, y1] 
    return torch.stack(box, dim=1)

我们还需要缩放框的大小。以下函数为我们执行此操作:

# 将框从 [0,1]x[0,1] 缩放到 [0, width]x[0, height]
def scale_boxes(output_box, width, height):
    box_coords = get_box_coords(output_box)
    scale_tensor = torch.Tensor(
                 [width, height, width, height]).to(
                 torch.cuda.current_device())
    return box_coords * scale_tensor

现在我们需要一个函数来封装我们的对象检测管道。detect下面的函数为我们做这件事。

# 对象检测管道
def detect(im, model, transform):
    device = torch.cuda.current_device()
    width = im.size[0]
    height = im.size[1]
   
    #平均-STD正规化所述输入图像(批量大小:1)
    img = transform(im).unsqueeze(0)
    img = img.to(device)
    
    # 演示模型默认只支持长宽比在0.5到2之间的图片
    assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600,
  	# 通过模型进行传播
    outputs = model(img)
    # 只保留 0.7+ 置信度的预测
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.85
   
    # 从 [0 转换框; 1] 图像缩放
    bboxes_scaled = scale_boxes(outputs['pred_boxes'][0, keep], width, height)
    return probas[keep], bboxes_scaled

现在我们需要做的就是得到我们想要的输出:

probs, bboxes = detect(image, detr, transform)

绘制结果

现在我们有了检测到的对象,我们可以使用一个简单的函数来可视化它们

# 绘制预测边界框
def plot_results(pil_img, prob, boxes,labels=True):
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    
    for prob, (x0, y0, x1, y1), color in zip(prob, boxes.tolist(),   COLORS * 100):
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,  
             fill=False, color=color, linewidth=2))
        cl = prob.argmax()
        text = f'{CLASSES[cl]}: {prob[cl]:0.2f}'
        if labels:
            ax.text(x0, y0, text, fontsize=15,
                bbox=dict(facecolor=color, alpha=0.75))
    plt.axis('off')
    plt.show()

现在我们可以可视化结果:

plot_results(image, probs, bboxes, labels=True)

在这里插入图片描述输出结果

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值