keras-yolo3项目之四:kmeans.py注释

kmeans.py文件主要是对训练数据集的标注框进行聚类,最终输出9个标注框,并将聚类好的9个标注框存储在yolo_anchors.txt文件中。

import numpy as np


class YOLO_Kmeans:
    """
    YOLO_Kmeans聚类模型
    """

    def __init__(self, cluster_number, filename):
        """
        初始化参数

        参数介绍:
        cluster_number:类别数量
        filename:文件名
        """
        self.cluster_number = cluster_number
        self.filename = "2012_train.txt"

    def iou(self, boxes, clusters):  # 1 box -> k clusters
        """
        计算面积的交并比
        boxes:标注框,格式:[[width, height],[width, height],...]
        clusters:从boxe中随机选择的标注框,默认9个
        """
        n = boxes.shape[0]    # 标注框的数量
        k = self.cluster_number    # 类别数量

        box_area = boxes[:, 0] * boxes[:, 1]    # width*height 计算标注框的面积,shape:(n,)
        box_area = box_area.repeat(k)   # 扩充数组元素,每个元素重复k次
        box_area = np.reshape(box_area, (n, k))   # shape:(n,k)

        cluster_area = clusters[:, 0] * clusters[:, 1]   #选中的标注框面积,shape:(9,)
        cluster_area = np.tile(cluster_area, [1, n])  #将面积整体(每9个为一组)重复n次,shape:(1,9*n)
        cluster_area = np.reshape(cluster_area, (n, k))

        box_w_matrix = np.reshape(boxes[:, 0].repeat(k), (n, k))    # 所有标注框的宽,shape:(n,k)
        cluster_w_matrix = np.reshape(np.tile(clusters[:, 0], (1, n)), (n, k))  #将选中的标注框的宽,shape:(n,k)
        min_w_matrix = np.minimum(cluster_w_matrix, box_w_matrix)   #取对应位置的最小值

        box_h_matrix = np.reshape(boxes[:, 1].repeat(k), (n, k))
        cluster_h_matrix = np.reshape(np.tile(clusters[:, 1], (1, n)), (n, k))
        min_h_matrix = np.minimum(cluster_h_matrix, box_h_matrix)
        inter_area = np.multiply(min_w_matrix, min_h_matrix)    # 对应位置元素的乘积

        result = inter_area / (box_area + cluster_area - inter_area)
        return result

    def avg_iou(self, boxes, clusters):
        """
        计算交并比的平均值
        """
        accuracy = np.mean([np.max(self.iou(boxes, clusters), axis=1)])
        return accuracy

    def kmeans(self, boxes, k, dist=np.median):
        """
        对标注框进行聚类,默认9类
        boxe:形如[[width, height],[width, height],...]
        k:最终聚成的类别数量,默认9
        """
        box_number = boxes.shape[0]    # 所有标注框的数量
        distances = np.empty((box_number, k))    #生成给定形状的数组
        last_nearest = np.zeros((box_number,))
        np.random.seed()
        # 随机选取k个标注框
        clusters = boxes[np.random.choice(
            box_number, k, replace=False)]  # init k clusters
        while True:

            distances = 1 - self.iou(boxes, clusters)

            current_nearest = np.argmin(distances, axis=1)
            if (last_nearest == current_nearest).all():
                break  # clusters won't change
            for cluster in range(k):
                clusters[cluster] = dist(  # update clusters
                    boxes[current_nearest == cluster], axis=0)

            last_nearest = current_nearest

        return clusters

    def result2txt(self, data):
        """
        聚类结果格式转换
        即,将聚类后的结果存储在文件yolo_anchors.txt中
        """
        f = open("yolo_anchors.txt", 'w')    # 存储聚类结果文件
        row = np.shape(data)[0]
        for i in range(row):
            if i == 0:
                x_y = "%d,%d" % (data[i][0], data[i][1])
            else:
                x_y = ", %d,%d" % (data[i][0], data[i][1])
            f.write(x_y)
        f.close()

    def txt2boxes(self):
        """
        从txt文件中提取标注框,每张图片中的标注框以空格为分隔符
        标注框格式:x_min,y_min,x_max,y_max,class_id;
        返回一个数组,包含每个标注框的宽和高,格式:[[width, height],[width, height],...]
        """
        f = open(self.filename, 'r')
        dataSet = []
        for line in f:
            infos = line.split(" ")
            length = len(infos)
            # 提取标注框
            for i in range(1, length):
                width = int(infos[i].split(",")[2]) - \
                    int(infos[i].split(",")[0])
                height = int(infos[i].split(",")[3]) - \
                    int(infos[i].split(",")[1])
                dataSet.append([width, height])
        result = np.array(dataSet)
        f.close()
        return result

    def txt2clusters(self):
        """
        对txt文件中提取出的标注框进行聚类操作
        """
        all_boxes = self.txt2boxes()    # 提取标注框
        result = self.kmeans(all_boxes, k=self.cluster_number)
        result = result[np.lexsort(result.T[0, None])]
        self.result2txt(result)
        print("K anchors:\n {}".format(result))
        print("Accuracy: {:.2f}%".format(
            self.avg_iou(all_boxes, result) * 100))


if __name__ == "__main__":
    """
    2012_train.txt文件内容格式如下:
    image_file_path box1 box2 ... boxN;
    box格式:x_min,y_min,x_max,y_max,class_id;
    path/to/img1.jpg 50,100,150,200,0 30,50,200,120,3
    path/to/img2.jpg 120,300,250,600,2
    """
    cluster_number = 9
    filename = "2012_train.txt"
    kmeans = YOLO_Kmeans(cluster_number, filename)
    kmeans.txt2clusters()
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

great-wind

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值