参考：《使用k-means聚类anchors》——太阳花的小绿豆的博客（CSDN）
目录
可视化:
代码:
import numpy as np
from matplotlib import pyplot as plt
import os
import matplotlib.pyplot as plt
from PIL import Image
# Folder holding the YOLO label .txt files.
txts_path = "your_folder_path"
# Folder holding the image .jpg files that match the labels by file name.
jpgs_path = "your_folder_path"

# Seed NumPy's global RNG so the random choice of initial centres is repeatable.
np.random.seed(0)

# Colour palette used by plot_clusters, indexed by cluster id (9 entries).
colors = np.array(
    [
        'blue', 'black', 'red', 'orange', 'gray',
        'green', 'purple', 'pink', 'yellow',
    ]
)
def plot_clusters(data, cls, clusters, title=""):
    """Scatter-plot points and cluster centres, save the figure as ``<title>.png`` and show it.

    Args:
        data: points to plot, one (x, y) pair per row.
        cls: per-point cluster index used to pick each point's colour, or None
            to draw every point in the first palette colour.
        clusters: cluster centres, one (x, y) pair per row; drawn as gold stars.
        title: plot title; also used as the output file name (".png" is appended).
    """
    if cls is None:
        c = [colors[0]] * data.shape[0]
    else:
        # Wrap cluster ids so more clusters than palette entries reuse colours
        # instead of raising IndexError.
        c = colors[np.asarray(cls) % len(colors)].tolist()
    plt.scatter(data[:, 0], data[:, 1], c=c)
    # Draw all centres in one call rather than one scatter per centre.
    centres = np.asarray(clusters)
    plt.scatter(centres[:, 0], centres[:, 1], c='gold', marker='*', s=150)
    plt.title(title)
    # savefig takes a single file name; passing '.png' as a second positional
    # argument raises TypeError. The extension selects the output format.
    plt.savefig(f"{title}.png")
    plt.show()
    plt.close()
def distances(data, clusters):
    """Return squared Euclidean distances between every point and every centre.

    Args:
        data: (N, D) array of points.
        clusters: (M, D) array of centres.

    Returns:
        (N, M) array whose entry [i, j] is the squared distance between
        data[i] and clusters[j]. The square root is skipped because it does
        not change which centre is nearest.
    """
    # Broadcast (N, 1, D) against (1, M, D) to form every pairwise difference.
    diff = data[:, np.newaxis] - clusters[np.newaxis]
    return np.square(diff).sum(axis=-1)
def k_means(data, k, dist=np.mean, *, plot=True):
    """Cluster the rows of ``data`` into ``k`` groups with Lloyd-style k-means.

    Args:
        data: array of points to cluster, one point per row.
        k: number of clusters.
        dist: reduction used to recompute a cluster centre from its members,
            called as ``dist(members, axis=0)`` (e.g. ``np.mean``).
        plot: when True (the default), plot the initial centres and every
            iteration with ``plot_clusters``.

    Returns:
        (k, D) array of cluster centres (floating point).
    """
    data_number = data.shape[0]
    # Sentinel that matches no real cluster index, so the first assignment can
    # never be mistaken for convergence (an all-zeros start could be).
    last_nearest = np.full(data_number, -1)
    # Initialise the centres with k distinct rows of the data.
    clusters = data[np.random.choice(data_number, k, replace=False)]
    # Centres are reductions such as means; keep them in floating point so
    # integer input is not truncated when an updated centre is written back.
    if not np.issubdtype(clusters.dtype, np.floating):
        clusters = clusters.astype(float)
    print(f"random cluster: \n {clusters}")
    if plot:
        plot_clusters(data, None, clusters, "random clusters")
    step = 0
    while True:
        d = distances(data, clusters)
        current_nearest = np.argmin(d, axis=1)
        if plot:
            plot_clusters(data, current_nearest, clusters, f"step {step}")
        if (last_nearest == current_nearest).all():
            break  # assignments are stable, so the centres will not move again
        for cluster in range(k):
            members = data[current_nearest == cluster]
            # An empty cluster (e.g. two initial centres on duplicate points)
            # would make dist() return NaN, which then wins every argmin.
            # Keep its previous centre instead.
            if members.shape[0] > 0:
                clusters[cluster] = dist(members, axis=0)
        last_nearest = current_nearest
        step += 1
    return clusters
def _load_box_sizes(txts_dir, jpgs_dir, img_size):
    """Return an (N, 2) float array of every label box's (w, h) in img_size-input pixels.

    Reads each ``<name>.txt`` in ``txts_dir`` together with ``<name>.jpg`` in
    ``jpgs_dir`` (the image is opened only to obtain its pixel size).
    """
    sizes = []
    # Sort so the point order (and hence the seeded random init) is reproducible.
    for txt_file in sorted(os.listdir(txts_dir)):
        if not txt_file.endswith(".txt"):
            continue
        txt_path = os.path.join(txts_dir, txt_file)
        jpg_path = os.path.join(jpgs_dir, os.path.splitext(txt_file)[0] + ".jpg")
        # The context manager closes the image file once its size is read.
        with Image.open(jpg_path) as image:
            width, height = image.size
        # YOLOv5 resizes each image so its longer side equals img_size while
        # keeping the aspect ratio; box sizes scale by the same factor.
        scale = img_size / max(width, height)
        with open(txt_path, "r") as file:
            for line in file:
                if not line.strip():
                    continue  # tolerate blank lines such as a trailing newline
                # YOLO label line: class x_center y_center width height, with
                # width/height normalised by the image size (floats, not ints).
                _, _, _, w, h = line.split()
                sizes.append((float(w) * width * scale, float(h) * height * scale))
    return np.array(sizes, dtype=float).reshape(-1, 2)


def main(img_size=640):
    """Cluster all label box sizes into 9 anchor candidates and plot the process.

    Args:
        img_size: training input size (YOLOv5 image size); box sizes are
            expressed in pixels of an image resized to this size. Defaults to 640.

    Raises:
        ValueError: if fewer boxes are found than the number of clusters.
    """
    # Collect every box (not just the first one per file).
    data = _load_box_sizes(txts_path, jpgs_path, img_size)
    k = 9
    if data.shape[0] < k:
        raise ValueError(f"need at least {k} boxes to form {k} clusters, found {data.shape[0]}")
    plt.scatter(data[:, 0], data[:, 1], c='blue')
    plt.title("initial data")
    plt.show()
    plt.close()
    clusters = k_means(data, k=k)
    print(f"k-means clusters: \n {clusters}")
# Script entry point: run the anchor-clustering pipeline only when executed
# directly, not when this module is imported.
if __name__ == '__main__':
    main()