yolo2的先验框的生成代码(解析)

下面是yolo2的先验框的生成代码,来自https://blog.csdn.net/weixin_44791964/article/details/102687531

因为理解不深,所以看下面代码后想了好久才想明白……我整理了一下,首先要明白几点。
①yolo2输出13×13的网格,每个网格有5个先验框。
②5个先验框共享一个中点的,也就是每个网格的中点,所以生成先验框就是要求出先验框的长和宽。
③下面的聚合函数思想就是将所有的真实框按照iou分成5堆,按照iou聚类的意思,就是一堆的真实框的iou比较高。怎么判断两个框之间的iou是否比较高的,就是将所有框和随机选出来的聚类中心作比较,如果两个框和聚类中心相比iou都比较大,那么这两个真实框的iou也是比较大,也就是一类的。

box包含着真实框的长宽信息。
box的形式为[[w1,h2],[w2,h2],……].也就是形状为[n,2],n表示真实框的个数,2表示二维,长和宽。
k表示聚类个数,这里也就是5.

def cas_iou(box, cluster):#计算交互比
    x = np.minimum(cluster[:, 0], box[0])
    y = np.minimum(cluster[:, 1], box[1])

    intersection = x * y
    area1 = box[0] * box[1]

    area2 = cluster[:, 0] * cluster[:, 1]
    iou = intersection / (area1 + area2 - intersection)

    return iou


def avg_iou(box, cluster):#平均交互比
    return np.mean([np.max(cas_iou(box[i], cluster)) for i in range(box.shape[0])])


def kmeans(box, k):
    # 取出一共有多少框
    row = box.shape[0]

    # 每个框各个点的位置,用于存放每个框和聚类中心的距离
    distance = np.empty((row, k))

    # 最后的聚类位置
    last_clu = np.zeros((row,))

    np.random.seed()

    # 随机选5个当聚类中心
    cluster = box[np.random.choice(row, k, replace=False)]
    # cluster = random.sample(row, k)
    while True:
        # 计算每一行距离五个点的iou情况。存放距离
        for i in range(row):
            distance[i] = 1 - cas_iou(box[i], cluster)

        # 每行取出最小点,也就是针对每个真实框和哪个先验框最接近,就将该框聚到哪一类
        near = np.argmin(distance, axis=1)

        if (last_clu == near).all():
            break

        # 求每一个类的中位点,取出离每个聚类中心最近的真实框,并对真实框的数据求中位数,作为先验框长和宽的输出。
        for j in range(k):
            cluster[j] = np.median(
                box[near == j], axis=0)

        last_clu = near

    return cluster

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值