YOLOX中decode 特征点解码过程可视化

YOLOX中decode 特征点解码过程可视化
该代码是特征宽高为20*20,batch_size=4,num_classes = 20进行解码可视化的过程。

import numpy as np
import matplotlib.pyplot as plt

def decode_for_vision(output):
    bs, hw = np.shape(output)[0], np.shape(output)[1:3]
    #  hw[0] * hw[1]  ------- 20,20
    output = np.reshape(output, [bs, hw[0] * hw[1], -1])
    #print(output)
    #output ------(4, 400, 23)
    grid_x, grid_y = np.meshgrid(np.arange(hw[1]), np.arange(hw[0]))
    #print(grid_x)
    grid = np.reshape(np.stack((grid_x, grid_y), 2), (1, -1, 2))
    #grid ---------(1, 400, 2)
    #print(grid)
    box_xy = (output[..., :2] + grid)
    #box_xy.shape      (4, 400, 2)
    #output[..., :2]    (4, 400, 2)
    #grid               (1, 400, 2)
    box_wh = np.exp(output[..., 2:4])
    #output[..., 2:4].shape       (4, 400, 2)
    #box_wh                       (4, 400, 2)
    fig = plt.figure()
    ax = fig.add_subplot(121)
    plt.ylim(-2.22, hw[0] + 2.22)
    plt.xlim(-2.22, hw[1] + 2.22)
    plt.scatter(grid_x, grid_y)
    plt.scatter(0, 0, c='black')
    plt.scatter(1, 0, c='black')
    plt.scatter(2, 0, c='black')
    plt.scatter(box_xy[0, 0, 0], box_xy[0, 0, 1], c='r')
    plt.scatter(box_xy[0, 1, 0], box_xy[0, 1, 1], c='g')
    plt.scatter(box_xy[0, 2, 0], box_xy[0, 2, 1], c='b')
    plt.gca().invert_yaxis()

    pre_left = box_xy[..., 0] - box_wh[..., 0] / 2
    pre_top = box_xy[..., 1] - box_wh[..., 1] / 2
    rect1 = plt.Rectangle([pre_left[0, 0], pre_top[0, 0]], box_wh[0, 0, 0], box_wh[0, 0, 1], color="r", fill=False)
    rect2 = plt.Rectangle([pre_left[0, 1], pre_top[0, 1]], box_wh[0, 1, 0], box_wh[0, 1, 1], color="r", fill=False)
    rect3 = plt.Rectangle([pre_left[0, 2], pre_top[0, 2]], box_wh[0, 2, 0], box_wh[0, 2, 1], color="r", fill=False)
    ax.add_patch(rect1)
    ax.add_patch(rect2)
    ax.add_patch(rect3)
    plt.show()


if __name__ == '__main__':
    batch_size = 4
    num_classes = 20

    feat = np.concatenate(
        [
            np.random.uniform(-1, 1, [batch_size, 20, 20, 1]),
            np.random.uniform(1, 3, [batch_size, 20, 20, 2]),
            np.random.uniform(0, 1, [batch_size, 20, 20, num_classes])
        ],
        axis=-1
    )
    # print(feat.shape)
    # s= np.random.uniform(-1, 1, [batch_size, 20, 20, 2])
    # s1 = np.random.uniform(1, 3, [batch_size, 20, 20, 2])
    # s2 = np.random.uniform(0, 1, [batch_size, 20, 20, num_classes])
    # print(s2.shape)
    decode_for_vision(feat)



#grid_x
# [[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]]

#grid_y
# [[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]]

如下图所示
在这里插入图片描述

1. box_xy = (output[…, :2] + grid)

output 是一个形状为 (batch_size, height * width, 23) 的数组,其中 23 是通道的数量。每个位置包含了23个值,这些值通常包括:

  • 预测的边界框的坐标(中心点的 x 和 y 坐标)
  • 预测的边界框的宽度和高度
  • 每个类的置信度分数

在 output[…, :2] 中,output 的形状为 (batch_size, height * width, 23),… 表示选取所有的前面的维度,而 :2 表示选择最后一维的前两个值。这意味着我们在提取预测的边界框中心点的 x 和 y 坐标(通常是第一个和第二个值)。

2. grid = np.reshape(np.stack((grid_x, grid_y), 2), (1, -1, 2))

在这个部分,grid_x 和 grid_y 是通过 np.meshgrid 创建的,它们的形状为 (height, width),表示网格中的每个位置的 x 和 y 坐标。

np.stack((grid_x, grid_y), 2)
np.stack 函数将 grid_x 和 grid_y 在新的维度(这里是第2个维度)上堆叠起来,因此生成的数组形状为 (height, width, 2),其中 2 表示堆叠的两个数组(grid_x 和 grid_y)。

np.reshape(…, (1, -1, 2))
np.reshape 函数将堆叠后的数组重新调整形状为 (1, -1, 2),具体如下:

  • 1 表示批量维度
  • -1 表示自动计算这一维度的大小,使总元素数保持不变
  • 2 表示每个位置的两个坐标(x 和 y)

这意味着我们将原始形状为 (height, width, 2) 的数组变换为形状为 (1, height * width, 2) 的数组。这是为了方便后续操作,使网格坐标与输出的形状匹配。

3. box_wh = np.exp(output[…, 2:4])
在这个代码中,output 的形状是 (batch_size, height * width, 23),其中 23 是每个预测位置上的特征数。特征的具体内容通常包括:

  • 预测的边界框的中心坐标 x 和 y(2 个值)。
  • 预测的边界框的宽度和高度(2 个值)。
  • 每个类的置信度分数(剩余的 19 个值,如果有 20 个类)。

因此,output[…, 2:4] 的意思是提取预测的边界框的宽度和高度。output[…, 2:4] 返回的是形状为 (batch_size, height * width, 2) 的数组,其中 2 代表宽度和高度两个值。

为什么使用 np.exp
模型在预测边界框的宽度和高度时,通常会预测其对数值。这是因为宽度和高度的值范围很大,直接预测这些值可能会使模型训练变得困难。因此,模型实际预测的是宽度和高度的对数值,这样可以将其转换回原始值:

box_wh = np.exp(output[..., 2:4])

具体过程
1.模型预测:

  • 模型预测的是边界框宽度和高度的对数值。
    2.指数转换:
  • 使用 np.exp 将对数值转换回实际的宽度和高度。
    3.得到实际宽度和高度:
  • 转换后的值代表预测的边界框的实际宽度和高度。

output[…, :2] 提取预测的边界框中心点的坐标。
output[…, 2:4] 提取预测的边界框的宽度和高度的对数值。
np.exp(output[…, 2:4]) 将对数值转换为实际的宽度和高度。

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

详细解释:

  1. 导入必要的库
import numpy as np
import matplotlib.pyplot as plt 
#导入 NumPy 用于数值计算,导入 Matplotlib 用于可视化。
  1. 解码并可视化检测器输出
def decode_for_vision(output):
    bs, hw = np.shape(output)[0], np.shape(output)[1:3]
    output = np.reshape(output, [bs, hw[0] * hw[1], -1])
    grid_x, grid_y = np.meshgrid(np.arange(hw[1]), np.arange(hw[0]))
    grid = np.reshape(np.stack((grid_x, grid_y), 2), (1, -1, 2))
    box_xy = (output[..., :2] + grid)
    box_wh = np.exp(output[..., 2:4])
  • 获取批量大小 (bs) 和网格尺寸 (hw)。
  • 将 output 重塑为 (batch_size, height * width, -1) 的形状。
  • 使用 np.meshgrid 创建网格坐标 (grid_x, grid_y)。
  • 堆叠并重塑网格坐标以匹配 output 的形状。
  • 计算预测的边界框中心点 (box_xy) 和宽高 (box_wh)。
  1. 可视化网格和边界框
    fig = plt.figure()
    ax = fig.add_subplot(121)
    plt.ylim(-2.22, hw[0] + 2.22)
    plt.xlim(-2.22, hw[1] + 2.22)
    plt.scatter(grid_x, grid_y)
    plt.scatter(0, 0, c='black')
    plt.scatter(1, 0, c='black')
    plt.scatter(2, 0, c='black')
    plt.scatter(box_xy[0, 0, 0], box_xy[0, 0, 1], c='r')
    plt.scatter(box_xy[0, 1, 0], box_xy[0, 1, 1], c='g')
    plt.scatter(box_xy[0, 2, 0], box_xy[0, 2, 1], c='b')
    plt.gca().invert_yaxis()
  • 创建图像并设置坐标轴范围。
  • 绘制网格坐标点。
  • 绘制三个边界框中心点。
  1. 可视化边界框矩形
    pre_left = box_xy[..., 0] - box_wh[..., 0] / 2
    pre_top = box_xy[..., 1] - box_wh[..., 1] / 2
    rect1 = plt.Rectangle([pre_left[0, 0], pre_top[0, 0]], box_wh[0, 0, 0], box_wh[0, 0, 1], color="r", fill=False)
    rect2 = plt.Rectangle([pre_left[0, 1], pre_top[0, 1]], box_wh[0, 1, 0], box_wh[0, 1, 1], color="r", fill=False)
    rect3 = plt.Rectangle([pre_left[0, 2], pre_top[0, 2]], box_wh[0, 2, 0], box_wh[0, 2, 1], color="r", fill=False)
    ax.add_patch(rect1)
    ax.add_patch(rect2)
    ax.add_patch(rect3)
    plt.show()
  • 计算边界框左上角的坐标 (pre_left, pre_top)。
  • 创建矩形边界框并添加到图像中。
  • 显示图像。
  1. 主程序入口
if __name__ == '__main__':
    batch_size = 4
    num_classes = 20

    feat = np.concatenate(
        [
            np.random.uniform(-1, 1, [batch_size, 20, 20, 1]),
            np.random.uniform(1, 3, [batch_size, 20, 20, 2]),
            np.random.uniform(0, 1, [batch_size, 20, 20, num_classes])
        ],
        axis=-1
    )
    decode_for_vision(feat)
  • 设置批量大小和类别数量。
  • 生成随机伪数据来模拟检测器输出。
  • 调用 decode_for_vision 函数解码并可视化这些数据。

总结
这个代码的主要目的是解码检测器输出数据并可视化网格和边界框,帮助理解检测器预测的边界框位置和大小。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值