YOLO K-means获取anchors大小代码详解

预备知识:应该了解yolo的基本操作,详见YOLO v1YOLO v2YOLO v3

首先应该了解yolo标签文件的格式,其格式为:图片的位置 框的4个坐标和1个类别ID (xmin,ymin,xmax,ymax,id) …。示例如下:

/home/aift/CV/detect/yolo3-keras-master/VOCdevkit/VOC2007/JPEGImages/000012.jpg 156,97,351,270,6
/home/aift/CV/detect/yolo3-keras-master/VOCdevkit/VOC2007/JPEGImages/000017.jpg 185,62,279,199,14 90,78,403,336,12
/home/aift/CV/detect/yolo3-keras-master/VOCdevkit/VOC2007/JPEGImages/000023.jpg 9,230,245,500,1 230,220,334,500,1 2,1,117,369,14 3,2,243,462,14 225,1,334,486,14
/home/aift/CV/detect/yolo3-keras-master/VOCdevkit/VOC2007/JPEGImages/000026.jpg 90,125,337,212,6
/home/aift/CV/detect/yolo3-keras-master/VOCdevkit/VOC2007/JPEGImages/000032.jpg 104,78,375,183,0 133,88,197,123,0 195,180,213,229,14 26,189,44,238,14
/home/aift/CV/detect/yolo3-keras-master/VOCdevkit/VOC2007/JPEGImages/000033.jpg 9,107,499,263,0 421,200,482,226,0 325,188,411,223,0
/home/aift/CV/detect/yolo3-keras-master/VOCdevkit/VOC2007/JPEGImages/000034.jpg 116,167,360,400,18 141,153,333,229,18
/home/aift/CV/detect/yolo3-keras-master/VOCdevkit/VOC2007/JPEGImages/000035.jpg 1,96,191,361,14 218,98,465,318,14
……

实现代码如下:(注释已经很详细了)

import numpy as np

'''
k-means拿到数据里所有的目标框,得到所有的宽和高,在这里面随机取得9个随机中心,之后以9个点为中心得到9个族,不断计算其他点到中点的距离调整
每个点所归属的族和中心,直到9个中心不再变即可。这9个中心的x,y就是整个数据的9个合适的anchors==框的宽和高。
部分注释来源:https://www.jianshu.com/p/3fddf7c08a58

建议单步调试看运行结果
'''
class YOLO_Kmeans:

    def __init__(self, cluster_number, in_filename, out_filename):
        # 读取kmeans的中心数
        self.cluster_number = cluster_number
        # 标签文件的文件名
        self.in_filename = in_filename      # "2007_train.txt"
        self.out_filename = out_filename    # "2007_train_clusters.txt"

    # 这里注意对每个box算其与9个clusters的iou
    def iou(self, boxes, clusters):  # 1 box -> k clusters
        # boxes : 所有的[[width, height], [width, height], …… ]
        # clusters : 9个随机的中心点[width, height]
        n = boxes.shape[0]        # 6301
        k = self.cluster_number   # 9

        # 所有的boxes的面积
        box_area = boxes[:, 0] * boxes[:, 1]  # w * h  (6301,)
        # 将box_area的每个元素重复k次
        box_area = box_area.repeat(k)  # (56709,)
        box_area = np.reshape(box_area, (n, k))  # (6301, 9)

        # 计算9个中点的面积
        cluster_area = clusters[:, 0] * clusters[:, 1]  # w * h   (9,)
        # 对cluster_area进行复制n份
        cluster_area = np.tile(cluster_area, [1, n])  # (1, 56709)
        cluster_area = np.reshape(cluster_area, (n, k))  # (6301, 9)

        # 获取box和中心的交叉w的宽  w
        box_w_matrix = np.reshape(boxes[:, 0].repeat(k), (n, k))                 # (6301, 9)
        cluster_w_matrix = np.reshape(np.tile(clusters[:, 0], (1, n)), (n, k))   # (6301, 9)
        min_w_matrix = np.minimum(cluster_w_matrix, box_w_matrix)                # (6301, 9)

        # 获取box和中心的交叉h的高  h
        box_h_matrix = np.reshape(boxes[:, 1].repeat(k), (n, k))
        cluster_h_matrix = np.reshape(np.tile(clusters[:, 1], (1, n)), (n, k))
        min_h_matrix = np.minimum(cluster_h_matrix, box_h_matrix)

        # 交叉点的面积
        inter_area = np.multiply(min_w_matrix, min_h_matrix)  # (6301, 9)  6301个boxes与9个clusters的交集面积
        # 9个交叉点和所有的boxes的iou值
        result = inter_area / (box_area + cluster_area - inter_area)

        return result  # (6301, 9)  6301个boxes与9个clusters的iou

    def avg_iou(self, boxes, clusters):
        #  计算9个中点与所有的boxes总的iou,n个点的平均iou
        accuracy = np.mean([np.max(self.iou(boxes, clusters), axis=1)])
        return accuracy

    def kmeans(self, boxes, k, dist=np.median):
        # boxes = [[宽, 高], [宽, 高], …… ]
        # k 中心点数
        # np.median 求众数
        box_number = boxes.shape[0]  # VOC2007: 6301
        distances = np.empty((box_number, k))  # (6301, 9)
        # 每个boxes属于哪个cluster
        last_nearest = np.zeros((box_number,))  # (6301,)
        np.random.seed()

        # 从所有的boxe中选取9个随机中心点(w, h)   https://www.cnblogs.com/cloud-ken/p/9931273.html
        clusters = boxes[np.random.choice(box_number, k, replace=False)]  # init k clusters   (9, 2)

        while True:
            # 计算所有的boxes和clusters的值(n,k)
            distances = 1 - self.iou(boxes, clusters)  # (6301, 9)  6301个boxes与9个clusters的iou距离(1 - iou)   (6301, 9)
            # 选取iou值最小的点(n,)
            current_nearest = np.argmin(distances, axis=1)  # (6301,)
            # 中心点未改变,跳出
            if (last_nearest == current_nearest).all():
                break  # clusters won't change
            # 计算每个群组的中心或者众数
            for cluster in range(k):
                clusters[cluster] = dist(boxes[current_nearest == cluster], axis=0)  # update clusters
            # 改变中心点(每个boxes属于哪个cluster)
            last_nearest = current_nearest

        return clusters

    def result2txt(self, result):
        # 把9个中心点,写入txt文件
        f = open(self.out_filename, 'w')
        row = np.shape(result)[0]  # 9
        for i in range(row):
            if i == 0:
                x_y = "%d,%d" % (result[i][0], result[i][1])
            else:
                x_y = ", %d,%d" % (result[i][0], result[i][1])
            f.write(x_y)
        f.close()

    def txt2boxes(self):
        # 打开文件
        f = open(self.in_filename, 'r')
        dataSet = []  # list
        # 读取文件
        for line in f:
            infos = line.split(" ")
            length = len(infos)
            # infons[0] 为图片的名称
            for i in range(1, length):
                # 获取文件的宽和高  这里说明Bbox的标注信息是左上角和右下角标注
                width = int(infos[i].split(",")[2]) - \
                        int(infos[i].split(",")[0])
                height = int(infos[i].split(",")[3]) - \
                         int(infos[i].split(",")[1])
                dataSet.append([width, height])   # [(w, h), (w, h), ……]
        result = np.array(dataSet)
        f.close()
        return result

    def txt2clusters(self):
        # 获取所有的文件目标的宽和高,width, height
        all_boxes = self.txt2boxes()  # ndarray: [(w, h), (w, h), ……]
        # result 9个聚类中心点
        result = self.kmeans(all_boxes, k=self.cluster_number)  # (9, 2)
        # 按第一列顺序排序(w)  https://www.cnblogs.com/hellcat/p/6874175.html#_label0_0
        result = result[np.lexsort(result.T[0, None])]  # (9, 2)   result.T[0, None]: (1, 9)
        # 把结果写入txt文件
        self.result2txt(result)
        print("K anchors:\n {}".format(result))
        #  计算9个中点与所有的boxes总的iou,n个点的平均iou
        print("Accuracy: {:.2f}%".format(
            self.avg_iou(all_boxes, result) * 100))


if __name__ == "__main__":
    cluster_number = 9                         # clusters数量
    in_filename = '2007_train.txt'             # 输入标签文件
    out_filename = '2007_train_clusters.txt'   # 输出clusters文件
    kmeans = YOLO_Kmeans(cluster_number, in_filename, out_filename)
    kmeans.txt2clusters()

done~

  • 2
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值