yolov5之kmeans算法可视化

参考:《使用k-means聚类anchors》(太阳花的小绿豆的博客,CSDN)

目录

 可视化:

 代码:


可视化:

 代码:

import numpy as np
from matplotlib import pyplot as plt
import os
import matplotlib.pyplot as plt
from PIL import Image
# Folder holding the label .txt files (one file per image)
txts_path = "your_folder_path"
# Folder holding the matching .jpg images (same base name as the .txt)
jpgs_path = "your_folder_path"
# Seed NumPy's global RNG so the random choice of initial centres is repeatable
np.random.seed(0)

# Plot palette indexed by cluster id; nine entries, one per cluster for k=9
colors = np.array("blue black red orange gray green purple pink yellow".split())


def plot_clusters(data, cls, clusters, title=""):
    """Scatter-plot 2-D points coloured by cluster and mark the cluster centres.

    The figure is saved as ``<title>.png`` in the current working directory,
    then shown and closed.

    Args:
        data: 2-D array; column 0 is drawn on the x axis, column 1 on y.
        cls: per-point cluster index (one per row of ``data``), or None to
            draw every point in the first palette colour.
        clusters: iterable of cluster centres, each indexable as [x, y].
        title: plot title, also used as the stem of the saved file name.
    """
    if cls is None:
        c = [colors[0]] * data.shape[0]
    else:
        # Wrap around the palette so k > len(colors) does not raise IndexError
        # (colours are then shared by more than one cluster).
        c = colors[np.asarray(cls) % len(colors)].tolist()

    plt.scatter(data[:, 0], data[:, 1], c=c)
    # Cluster centres drawn as large gold stars on top of the points.
    for centre in clusters:
        plt.scatter(centre[0], centre[1], c='gold', marker='*', s=150)
    plt.title(title)
    # savefig takes a single file-name positional argument, so build the name
    # by concatenation; passing '.png' as a second positional arg is an error.
    plt.savefig(title + '.png')
    plt.show()
    plt.close()


def distances(data, clusters):
    """Return the squared Euclidean distance from every point to every centre.

    ``data`` is broadcast as (N, 1, D) against ``clusters`` as (1, M, D), so the
    result is an (N, M) matrix. No square root is taken; callers in this file
    only use the argmin, which squaring does not change.
    """
    points = data[:, np.newaxis]        # (N, 1, D)
    centres = clusters[np.newaxis]      # (1, M, D)
    diff = centres - points             # (N, M, D)
    return np.square(diff).sum(axis=-1)


def k_means(data, k, dist=np.mean, *, plot=True, max_iter=None):
    """
    Cluster ``data`` into ``k`` groups with Lloyd-style k-means iteration.

    Initial centres are ``k`` distinct rows of ``data`` drawn with NumPy's
    global RNG. Each step assigns every point to its nearest centre (via
    ``distances``) and moves each centre to ``dist(members, axis=0)``.
    Iteration stops once the assignment no longer changes, or after
    ``max_iter`` centre updates if that is given.

    Args:
        data: 2-D array of points to cluster, one point per row.
        k: number of clusters (must not exceed the number of rows).
        dist: reducer used to update a centre from its member points,
            called as ``dist(points, axis=0)`` (e.g. np.mean, np.median).
        plot: if True (default), call ``plot_clusters`` for the initial
            centres and at every step.
        max_iter: optional cap on the number of centre updates; None
            (default) iterates until the assignment is stable.

    Returns:
        Array with one row per cluster centre (floating dtype).

    Raises:
        ValueError: if ``k`` is larger than the number of points.
    """
    data_number = data.shape[0]
    if k > data_number:
        raise ValueError(f"k={k} is larger than the number of points ({data_number})")

    # init k clusters; promote integer data to float so mean/median updates
    # written back into `clusters` are not truncated
    clusters = data[np.random.choice(data_number, k, replace=False)]
    if not np.issubdtype(clusters.dtype, np.floating):
        clusters = clusters.astype(float)
    print(f"random cluster: \n {clusters}")
    if plot:
        plot_clusters(data, None, clusters, "random clusters")

    # -1 never equals a real cluster index, so the first step cannot be
    # mistaken for convergence (an all-zeros start would stop immediately
    # whenever every point was initially closest to cluster 0).
    last_nearest = np.full(data_number, -1)
    step = 0
    while True:
        d = distances(data, clusters)
        current_nearest = np.argmin(d, axis=1)

        if plot:
            plot_clusters(data, current_nearest, clusters, f"step {step}")

        if np.array_equal(last_nearest, current_nearest):
            break  # assignment stable -> clusters won't change
        if max_iter is not None and step >= max_iter:
            break
        for cluster in range(k):
            members = data[current_nearest == cluster]
            # Keep the previous centre for an empty cluster: reducing an empty
            # array yields NaN, after which the cluster could never win a point.
            if len(members):
                clusters[cluster] = dist(members, axis=0)
        last_nearest = current_nearest
        step += 1

    return clusters


def _scaled_box_sizes(txt_path, width, height, img_size=640):
    """Return the (w, h) of every box in one label file, in pixels at img_size.

    Each non-blank line is unpacked as ``category x y w h``; w and h are
    treated as fractions of the image width/height (YOLO-style normalised
    labels — assumed, not validated here). Sizes are rescaled so the image's
    longer side equals ``img_size`` with the aspect ratio kept, which is how
    YOLOv5's autoanchor scales label sizes.
    """
    scale = img_size / max(width, height)
    sizes = []
    with open(txt_path, "r", encoding="utf-8") as file:
        for line in file:
            parts = line.split()
            if not parts:
                continue  # tolerate blank / trailing empty lines
            _category, _x, _y, w, h = parts
            # Normalised sizes are decimals such as "0.25": parse with float(),
            # since int() would raise ValueError on them.
            sizes.append((float(w) * width * scale, float(h) * height * scale))
    return sizes


def main():
    """Collect label box sizes, plot them, and k-means them into 9 anchors.

    For every ``.txt`` file in ``txts_path``, the same-named ``.jpg`` in
    ``jpgs_path`` is opened only to read its width/height. All boxes of all
    files are scaled to a 640-pixel input and clustered with ``k_means``.
    """
    # One (w, h) pair per labelled box across the whole dataset
    box_sizes = []

    # sorted() fixes the point order, so the seeded random initial centres
    # (and hence the result) are reproducible across file systems.
    for txt_file in sorted(os.listdir(txts_path)):
        if not txt_file.endswith(".txt"):
            continue
        txt_path = os.path.join(txts_path, txt_file)
        jpg_path = os.path.join(jpgs_path, os.path.splitext(txt_file)[0] + ".jpg")
        # Context manager releases the image file handle; only the size is used.
        with Image.open(jpg_path) as image:
            width, height = image.size
        # Since yolov5 imgsize is set to 640, convert the box sizes to that scale.
        box_sizes.extend(_scaled_box_sizes(txt_path, width, height, img_size=640))

    wh = np.array(box_sizes, dtype=float).reshape(-1, 2)
    x, y = wh[:, 0], wh[:, 1]
    plt.scatter(x, y, c='blue')
    plt.title("initial data")
    plt.show()
    plt.close()

    clusters = k_means(wh, k=9)
    print(f"k-means cluster: \n {clusters}")


if __name__ == '__main__':
    # Script entry point: run the full read -> plot -> cluster pipeline.
    main()
  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一鹿向晗99

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值