YOLOv3 anchors 的 k-means 聚类脚本

最近又回过头来看了下经典的 YOLOv3 算法，其中 anchors 的聚类用的是 k-means。下面是一个 Python 实现，供大家参考讨论。注意：YOLOv2/v3 论文中聚类使用的距离是 d = 1 − IoU，而本脚本直接调用 cv2.kmeans，使用的是 (w, h) 上的欧氏距离。

import os, sys 
import numpy as np 
import cv2 
import math 

# Command-line interface:
#   datacfg          key=value config file containing a "train = <list file>" entry
#   num_of_clusters  number of anchors to produce (k of k-means)
#   width, height    network input size; normalized label w/h are multiplied by these
if len(sys.argv) < 5:
    # Fill in the script name so the usage line is actually readable.
    print('usage:\n\t%s datacfg num_of_clusters width height' % sys.argv[0])
    sys.exit(-1)

datacfg = sys.argv[1]
num_of_clusters = int(sys.argv[2])
width = int(sys.argv[3])
height = int(sys.argv[4])

print()

# Find the "train = <path>" entry in the data config. That path names the file
# listing the training images (one path per line) read by the next step.
train_file = None
with open(datacfg) as fr:
    for line in fr:
        # Split on the first '=' only, so a path that itself contains '='
        # is kept intact; an empty value is not accepted as a train file.
        key, sep, value = line.partition('=')
        if key.strip() == 'train' and sep and value.strip():
            train_file = value.strip()
            break

if train_file is None:
    print('can\'t find train in data file')
    sys.exit(-1)

# Collect the pixel-space (w, h) of every valid ground-truth box.
# For each image path in train_file, the label file is looked up at
# <parent of the image's directory>/labels/<image stem>.txt, and each of its
# lines must have exactly 5 fields "<class> <x> <y> <w> <h>" where x, y, w, h
# lie in (0, 1]. Only w and h are kept, scaled by the network width/height.
number_of_boxes = 0
r_wh_arr = []

with open(train_file) as fr:
    for i, t in enumerate(fr):
        t = t.strip()
        if not t:
            # Skip blank lines (e.g. a trailing newline) in the image list
            # instead of resolving them to the current directory.
            continue
        t = os.path.abspath(t)
        t1 = os.path.join(os.path.dirname(os.path.dirname(t)), 'labels',
                          os.path.basename(os.path.splitext(t)[0]) + '.txt')
        if not os.path.exists(t1):
            print('can\'t find %s'%(t1))
            continue
        with open(t1) as lab_fr:
            for lab_l in lab_fr:
                lt = lab_l.split()
                if not lt:
                    continue  # blank line in a label file is not an error
                if len(lt) != 5:
                    print('wrong label:',t1)
                    continue
                try:
                    lb_x, lb_y, lb_w, lb_h = (float(v) for v in lt[1:])
                except ValueError:
                    # Non-numeric field: report it like any other bad label
                    # rather than aborting the whole run.
                    print('wrong label:',t1)
                    continue

                # Written as "inside the range" so NaN values are rejected too.
                if not (0 < lb_x <= 1 and 0 < lb_y <= 1 and
                        0 < lb_w <= 1 and 0 < lb_h <= 1):
                    print('wrong label:',t1)
                    continue
                number_of_boxes += 1
                r_wh_arr.append([lb_w*width,lb_h*height])
                print("\r loaded \t image: %d \t box: %d"%(i+1, number_of_boxes),end='')

    print("\n all loaded. ")
    
# cv2.kmeans asserts 0 < K <= number of samples and fails with an opaque
# error otherwise, so check the inputs up front with a readable message.
if num_of_clusters < 1 or number_of_boxes < num_of_clusters:
    print('need at least num_of_clusters (%d >= 1) valid boxes, found %d'
          % (num_of_clusters, number_of_boxes))
    sys.exit(-1)

# Stop after at most 10000 iterations, or once the centers stop moving (epsilon 0).
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10000, 0)

# k-means++ seeding; with 10 attempts cv2.kmeans keeps the most compact result.
flags = cv2.KMEANS_PP_CENTERS
# One sample per box: an (w, h) pair in pixels. cv2.kmeans requires float32.
datas = np.asarray(r_wh_arr, dtype=np.float32)
# NOTE: this clusters with plain Euclidean distance on (w, h), not the
# 1 - IoU distance described in the YOLOv2 paper.
compactness, labels, centers = cv2.kmeans(datas, num_of_clusters, None, criteria, 10, flags)

def _wh_iou(box, anchor):
    """Return the IoU of two boxes given only as (w, h), assumed to share a center."""
    inter = min(box[0], anchor[0]) * min(box[1], anchor[1])
    union = box[0] * box[1] + anchor[0] * anchor[1] - inter
    return inter / union


# Anchor quality: for every box, take the IoU with its best-matching anchor
# (highest IoU). Then report the mean over all boxes as a percentage. The
# Euclidean-nearest center in (w, h) space is not necessarily the anchor with
# the highest IoU, so it is not used here.
avg_iou = 0
for di, dt in enumerate(datas):
    iou, best_c = max(((_wh_iou(dt, c), c) for c in centers),
                      key=lambda pair: pair[0])
    if iou > 1 or iou < 0:
        # Only reachable with degenerate sizes; such boxes are left out of the sum.
        print('Wrong label: i = %d, box_w = %f, box_h = %f, anchor_w = %f, anchor_h = %f, iou = %f '%(di, dt[0], dt[1], best_c[0], best_c[1], iou))
    else:
        avg_iou += iou

# max(..., 1) avoids a ZeroDivisionError when no boxes were loaded.
print('\navg IoU = %2.2f %%\n '%(100 * avg_iou / max(number_of_boxes, 1)))

# Print the cluster centers as a flat "w,h, w,h, ..." list, smallest area first.
anchor_pairs = sorted(centers.tolist(), key=lambda wh: wh[0] * wh[1])
anchor_text = ', '.join('%.4f,%.4f' % (aw, ah) for aw, ah in anchor_pairs)
print('anchors = ' + anchor_text, end='')

print('\n')

 loaded          image: 792      box: 939
 all loaded. 

avg IoU = 83.99 %
 
anchors = 21.2969,11.8091, 17.5623,17.6491, 16.9178,24.6314, 22.1751,20.8706, 32.5937,16.0246, 21.6062,30.9830, 17.6986,39.9335, 74.9313,148.7542, 150.5354,153.2271

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值