最近又回过头来看了一下经典的 YOLOv3 算法,其中 anchors 的聚类用的是 k-means。下面是一个 Python 实现,供大家参考和讨论。
import os, sys
import numpy as np
import cv2
import math
def parse_train_path(cfg_lines):
    """Return the value of the ``train`` key from the lines of a ``.data`` file.

    Each line is expected to look like ``key = value``.  Only the first ``=``
    separates key from value, so a path that itself contains ``=`` is kept
    intact.  Returns ``None`` when no ``train = ...`` line is present.
    """
    for line in cfg_lines:
        key, sep, value = line.partition('=')
        if sep and key.strip() == 'train':
            return value.strip()
    return None


def label_path_for(image_path):
    """Map an image path to its label file ``<grandparent>/labels/<stem>.txt``.

    The image path is made absolute first; the label file sits in a
    ``labels`` directory that is a sibling of the image's own directory.
    """
    image_path = os.path.abspath(image_path)
    dataset_dir = os.path.dirname(os.path.dirname(image_path))
    stem = os.path.basename(os.path.splitext(image_path)[0])
    return os.path.join(dataset_dir, 'labels', stem + '.txt')


def parse_label_line(line, width, height):
    """Parse one label line ``class x y w h`` with coordinates in (0, 1].

    Returns ``[w * width, h * height]`` (the box size scaled to the network
    input), or ``None`` when the line does not have exactly five fields, a
    coordinate is not a number, or a coordinate lies outside ``(0, 1]``.
    """
    fields = line.split()
    if len(fields) != 5:
        return None
    try:
        coords = [float(v) for v in fields[1:]]
    except ValueError:
        return None
    # Written as "0 < v <= 1" so that NaN is rejected as well.
    if not all(0 < v <= 1 for v in coords):
        return None
    _, _, box_w, box_h = coords
    return [box_w * width, box_h * height]


def load_boxes(train_file, width, height):
    """Collect the scaled ``[w, h]`` of every valid box listed via *train_file*.

    *train_file* holds one image path per line.  Images without a label file
    and malformed label lines are reported on stdout and skipped; blank lines
    are ignored.
    """
    boxes = []
    with open(train_file) as list_fr:
        for image_no, image_line in enumerate(list_fr, 1):
            image_path = image_line.strip()
            if not image_path:
                continue
            label_path = label_path_for(image_path)
            if not os.path.exists(label_path):
                print('can\'t find %s' % (label_path))
                continue
            with open(label_path) as lab_fr:
                for lab_line in lab_fr:
                    if not lab_line.strip():
                        continue
                    box = parse_label_line(lab_line, width, height)
                    if box is None:
                        print('wrong label:', label_path)
                        continue
                    boxes.append(box)
            print("\r loaded \t image: %d \t box: %d" % (image_no, len(boxes)), end='')
    print("\n all loaded. ")
    return boxes


def cluster_anchors(boxes, num_of_clusters):
    """Run OpenCV k-means (k-means++ seeding, 10 attempts) on the box sizes.

    Returns the ``(num_of_clusters, 2)`` float32 array of cluster centers.
    Raises ``ValueError`` when there are fewer boxes than clusters.
    """
    datas = np.asarray(boxes, dtype=np.float32).reshape(-1, 2)
    if len(datas) < num_of_clusters:
        raise ValueError('need at least %d boxes, got %d' % (num_of_clusters, len(datas)))
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10000, 0)
    _compactness, _labels, centers = cv2.kmeans(
        datas, num_of_clusters, None, criteria, 10, cv2.KMEANS_PP_CENTERS)
    return centers


def average_iou(boxes, centers):
    """Return the mean IoU between each box and its nearest center.

    "Nearest" is measured by Euclidean distance in (w, h) space; ties go to
    the first such center.  IoU is computed as if box and center shared a
    corner, i.e. from widths and heights only.  Sizes are assumed positive.
    Raises ``ValueError`` when *boxes* or *centers* is empty.
    """
    boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 2)
    centers = np.asarray(centers, dtype=np.float64).reshape(-1, 2)
    if len(boxes) == 0 or len(centers) == 0:
        raise ValueError('boxes and centers must both be non-empty')
    # (n_boxes, n_centers) matrix of squared distances; sqrt is monotonic so
    # it is not needed to find the minimum.
    dist2 = ((boxes[:, None, 0] - centers[None, :, 0]) ** 2
             + (boxes[:, None, 1] - centers[None, :, 1]) ** 2)
    nearest = centers[dist2.argmin(axis=1)]
    intersect = np.minimum(boxes, nearest).prod(axis=1)
    union = boxes.prod(axis=1) + nearest.prod(axis=1) - intersect
    return float(np.mean(intersect / union))


def format_anchors(centers):
    """Format centers as ``w,h, w,h, ...`` sorted by ascending area."""
    ordered = sorted(np.asarray(centers).tolist(), key=lambda wh: wh[0] * wh[1])
    return ', '.join('%.4f,%.4f' % (wh[0], wh[1]) for wh in ordered)


def main(argv=None):
    """Command-line entry point: ``prog datacfg num_of_clusters width height``.

    Reads the training list named in *datacfg*, clusters the label box sizes
    and prints the average IoU and the resulting anchors.  Exits with status
    -1 on bad arguments or unusable input.
    """
    argv = sys.argv if argv is None else argv
    usage = 'usage:\n\t%s datacfg num_of_clusters width height' % (argv[0] if argv else '')
    if len(argv) < 5:
        print(usage)
        sys.exit(-1)
    datacfg = argv[1]
    try:
        num_of_clusters, width, height = (int(a) for a in argv[2:5])
    except ValueError:
        print(usage)
        sys.exit(-1)
    if num_of_clusters <= 0 or width <= 0 or height <= 0:
        print('num_of_clusters, width and height must be positive integers')
        sys.exit(-1)
    print()

    with open(datacfg) as fr:
        train_file = parse_train_path(fr)
    if not train_file:
        print('can\'t find train in data file')
        sys.exit(-1)

    boxes = load_boxes(train_file, width, height)
    # k-means needs at least one sample per cluster; fail with a clear message
    # instead of an OpenCV assertion or a division by zero later on.
    if len(boxes) < num_of_clusters:
        print('need at least %d valid boxes, found %d' % (num_of_clusters, len(boxes)))
        sys.exit(-1)

    centers = cluster_anchors(boxes, num_of_clusters)
    print('\navg IoU = %2.2f %%\n ' % (100 * average_iou(boxes, centers)))
    print('anchors = %s\n' % format_anchors(centers))


if __name__ == '__main__':
    main()
loaded image: 792 box: 939
all loaded.
avg IoU = 83.99 %
anchors = 21.2969,11.8091, 17.5623,17.6491, 16.9178,24.6314, 22.1751,20.8706, 32.5937,16.0246, 21.6062,30.9830, 17.6986,39.9335, 74.9313,148.7542, 150.5354,153.2271