【人工智能笔记】第十五节 通过K-means聚类算法,找到Yolo最佳anchors

K-means聚类算法,是通过随机选择N个聚类中心,找到所有点距离最近的中心,计算属于每个中心的点的平均值,用平均值更新中心位置。重复上述步骤不断更新中心位置,直到中心位置不变为止。

如图:

下面是实现方式:

素材便签文件如下:

<?xml version='1.0' encoding='utf-8'?>
<annotation>
  <folder>VOC2007</folder>
  <filename>domain1/100001.jpg</filename>
  <segmented>0</segmented>
  <size>
    <width>407</width>
    <height>247</height>
    <depth>3</depth>
  </size>
  <object>
    <name>lighter</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>240</xmin>
      <ymin>88</ymin>
      <xmax>306</xmax>
      <ymax>124</ymax>
    </bndbox>
  </object>
</annotation>

K-means聚类算法实现,kmeans.py:

import file_helper as FileHelper
import image_helper as ImageHelper
import argparse
import xml.dom.minidom as xml
import numpy as np

# 启动参数
parser = argparse.ArgumentParser()
parser.add_argument('--file_path', default='labels')
args = parser.parse_args()

label_path = args.file_path
file_list = FileHelper.ReadFileList(label_path, r'.xml$')

# 网络图片大小
image_size = (416, 416)

# 缩放后的图片宽高
label_image_wh = []
# 遍历标签文件
for xml_path in file_list:
    # 读取xml文件
    # print('xml_path:', xml_path)
    xml_data = xml.parse(xml_path)
    collection = xml_data.documentElement
    # 读取标签图大小
    label_image_size = collection.getElementsByTagName("size")[0]
    label_image_width = int(label_image_size.getElementsByTagName("width")[0].childNodes[0].data)
    label_image_height = int(label_image_size.getElementsByTagName("height")[0].childNodes[0].data)
    # 对象列表
    object_list = collection.getElementsByTagName("object")

    # 读取目标框信息
    for object_item in object_list:
        # 原始点列表
        object_bndbox = object_item.getElementsByTagName("bndbox")[0]
        point_x1 = float(object_bndbox.getElementsByTagName("xmin")[0].childNodes[0].data)
        point_y1 = float(object_bndbox.getElementsByTagName("ymin")[0].childNodes[0].data)
        point_x2 = float(object_bndbox.getElementsByTagName("xmax")[0].childNodes[0].data)
        point_y2 = float(object_bndbox.getElementsByTagName("ymax")[0].childNodes[0].data)
        points = np.array([[point_x1, point_y1, point_x2, point_y2]])
        points = points.reshape((-1,2))
        points, _ = ImageHelper.opencvProportionalResizePoint((label_image_width, label_image_height), image_size, points=points)
        points = points.reshape((-1,4))
        label_image_wh.append([points[0][2]-points[0][0], points[0][3]-points[0][1]])
label_image_wh = np.array(label_image_wh, dtype=np.float32)

print('目标数量:', label_image_wh.shape)
# 初始化9个宽高
anchors_wh = []
for i in range(9):
    anchors_wh.append(label_image_wh[i])
anchors_wh = np.array(anchors_wh, dtype=np.float32)
print('初始化宽高:', anchors_wh)
# 遍历所有宽高
while True:
    anchors_wh_old = anchors_wh.copy()
    # 最近点集合
    anchors_points = [[] for _ in range(len(anchors_wh))]
    for label_image_wh_index in range(label_image_wh.shape[0]):
        min_index = 0
        min_d = None
        for anchors_index in range(len(anchors_wh)):
            d = np.sum(np.square(anchors_wh[anchors_index] - label_image_wh[label_image_wh_index]))
            if min_d is None or d < min_d:
                min_d = d
                min_index = anchors_index
        anchors_points[min_index].append(label_image_wh[label_image_wh_index])
    anchors_points = np.array(anchors_points)
    # 按平均值调整宽高
    for anchors_index in range(len(anchors_wh)):
        anchors_wh[anchors_index] = np.mean(anchors_points[anchors_index], axis=0)
    # print('调整后的宽高:', anchors_wh)
    if (anchors_wh_old == anchors_wh).all():
        break
# 排序
anchors_wh = anchors_wh.tolist()
print('kmeans的宽高:', np.array(anchors_wh))
anchors_wh.sort(key=lambda a: a[0]*a[1])
print('kmeans的宽高(排序):', np.array(anchors_wh))


执行脚本,可嵌套文件夹:

python kmeans.py --file_path 素材地址

K-means聚类算法参考资料:https://www.cnblogs.com/txx120/p/11487674.html

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

PPHT-H

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值