How YOLOv2 and YOLOv3 compute anchor boxes with k-means clustering

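k-means needs data, and the number of cluster centers has to be chosen by hand. The centers can be initialized at random, but we still need a way to measure the distance from each sample to a cluster center, and how that distance is measured is critical.
With the standard Euclidean distance, large boxes produce larger errors than small boxes. For example, (100-95)^2=25 while (5-2.5)^2=6.25, even though the first pair of sizes differs by only 5% and the second by 50%. A different distance is therefore used. The goal of the clustering is for the anchor boxes to have a high IOU with nearby ground truth boxes, and this has no direct relation to the size of the anchor box. The custom distance is:
d(box,centroid)=1-IOU(box,centroid)
A smaller distance to the cluster center is better, whereas a larger IOU is better, so 1 - IOU is used. This guarantees that the smaller the distance, the larger the IOU.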
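To make the contrast concrete, here is a minimal sketch that scores the two size pairs from the example above with both measures. The helper names `iou_wh` and `sq_euclidean` are illustrative only and are not part of the original script; boxes are treated as squares given by (w, h) with a shared center.

```python
def iou_wh(box, centroid):
    """IoU of two boxes given as (w, h), assuming both share the same center."""
    w, h = box
    c_w, c_h = centroid
    inter = min(w, c_w) * min(h, c_h)
    union = w * h + c_w * c_h - inter
    return inter / union


def sq_euclidean(box, centroid):
    """Squared Euclidean distance between two (w, h) pairs."""
    return (box[0] - centroid[0]) ** 2 + (box[1] - centroid[1]) ** 2


large = ((100.0, 100.0), (95.0, 95.0))  # sizes differ by 5%
small = ((5.0, 5.0), (2.5, 2.5))        # sizes differ by 50%

large_euclid = sq_euclidean(*large)  # 25 + 25 = 50.0
small_euclid = sq_euclidean(*small)  # 6.25 + 6.25 = 12.5
large_dist = 1 - iou_wh(*large)      # 1 - 0.9025 = 0.0975
small_dist = 1 - iou_wh(*small)      # 1 - 0.25 = 0.75

print(large_euclid, small_euclid, large_dist, small_dist)
```

The Euclidean measure penalizes the well-matched large pair more than the poorly matched small pair, while 1 - IOU ranks them the other way round, which is the behavior wanted for anchors.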

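① The raw data for clustering is a detection dataset of which only the annotation boxes are used. YOLOv2 and YOLOv3 both generate a TXT file containing the position and class of each annotation box, where each line contains (x_j,y_j,w_j,h_j),j\in\{1,2,...,N\}, i.e. the coordinates of the ground truth box relative to the original image. (x_j,y_j) is the center of the box, (w_j,h_j) is its width and height, and N is the total number of annotation boxes;
② First, k cluster centers (W_i,H_i),i\in\{1,2,...,k\} are given, where W_i,H_i are the width and height of the anchor boxes. Because the position of an anchor box is not fixed, there are no (x,y) coordinates, only a width and a height;
③ Compute the distance between every annotation box and every cluster center, d=1-IOU(annotation box, cluster center). For this computation the center of each annotation box is made to coincide with the cluster center, since otherwise the IOU cannot be computed, i.e. d=1-IOU\left[(x_j,y_j,w_j,h_j),(x_j,y_j,W_i,H_i)\right],j\in\{1,2,...,N\},i\in\{1,2,...,k\}. Each annotation box is assigned to the cluster center with the smallest "distance";
④ Once all annotation boxes have been assigned, recompute the center of each cluster as W_i^{'}=\frac{1}{N_i}\sum_{j\in C_i} w_j, H_i^{'}=\frac{1}{N_i}\sum_{j\in C_i} h_j, where C_i is the set of annotation boxes assigned to cluster i and N_i is their number. In other words, take the mean width and height of all annotation boxes in the cluster.
Repeat steps ③ and ④ until the cluster centers change very little.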
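The four steps above can be sketched compactly with NumPy. This is a simplified illustration under stated assumptions, not the script discussed below: the function names `iou_matrix` and `kmeans_iou`, the `seed` and `max_iter` parameters, and the toy data are all made up for the example, and iteration stops when the assignments no longer change.

```python
import numpy as np


def iou_matrix(boxes, centroids):
    """Pairwise IoU between (N, 2) boxes and (k, 2) centroids with aligned centers."""
    inter = (np.minimum(boxes[:, None, 0], centroids[None, :, 0]) *
             np.minimum(boxes[:, None, 1], centroids[None, :, 1]))
    area_boxes = boxes[:, 0] * boxes[:, 1]
    area_centroids = centroids[:, 0] * centroids[:, 1]
    return inter / (area_boxes[:, None] + area_centroids[None, :] - inter)


def kmeans_iou(boxes, k, seed=0, max_iter=100):
    rng = np.random.default_rng(seed)
    # Step 2: pick k distinct annotation boxes as the initial centers
    centroids = boxes[rng.choice(len(boxes), size=k, replace=False)].copy()
    assignments = np.full(len(boxes), -1)
    for _ in range(max_iter):
        # Step 3: distance d = 1 - IoU, assign each box to the nearest center
        dist = 1.0 - iou_matrix(boxes, centroids)
        new_assignments = dist.argmin(axis=1)
        if np.array_equal(new_assignments, assignments):
            break
        assignments = new_assignments
        # Step 4: each center becomes the mean (w, h) of its cluster
        for i in range(k):
            members = boxes[assignments == i]
            if len(members):  # an empty cluster keeps its previous center
                centroids[i] = members.mean(axis=0)
    return centroids, assignments


# Toy data: three small boxes and three large boxes, as normalized (w, h)
boxes = np.array([[0.10, 0.12], [0.11, 0.10], [0.12, 0.11],
                  [0.50, 0.60], [0.55, 0.58], [0.52, 0.62]])
centroids, assignments = kmeans_iou(boxes, k=2)
print(centroids, assignments)
```

On this toy data the small and large boxes end up in separate clusters, and each center is the mean width and height of its group.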

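The implementation is mainly scripts/gen_anchors.py from AlexeyAB/darknet, modified here in places to account for the differences between yolov2 and yolov3. The anchors required by the yolov2 config file yolov2.cfg are relative to the feature map, so the values are small, mostly below 13. The anchors required by the yolov3 config file yolov3.cfg (three per detection scale) are relative to the original image, so they are comparatively large. The input image size (a multiple of 32) also affects the output.
Example:
In yolov2.cfg, [region] anchors =  0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828
In yolov3.cfg, [yolo] anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326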
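The difference between the two anchor conventions comes down to a scale factor applied to the normalized cluster centers. The sketch below shows that conversion; `scale_anchor` and its `stride` parameter are hypothetical names for illustration, and the downsampling factor of 32 is the value the script assumes for yolov2.

```python
def scale_anchor(w, h, input_size=416, yolo_version="yolov3", stride=32):
    """Convert a normalized (w, h) cluster center into cfg units.

    yolov2: units of feature-map cells (input_size / stride cells per side).
    yolov3: pixels of the network input.
    """
    if yolo_version == "yolov2":
        factor = input_size / stride
    elif yolo_version == "yolov3":
        factor = input_size
    else:
        raise ValueError("yolo_version must be 'yolov2' or 'yolov3'")
    return w * factor, h * factor


# A box covering a quarter of the image width and half of its height
v2_anchor = scale_anchor(0.25, 0.5, yolo_version="yolov2")  # 416 / 32 = 13 cells per side
v3_anchor = scale_anchor(0.25, 0.5, yolo_version="yolov3")
print(v2_anchor, v3_anchor)  # (3.25, 6.5) (104.0, 208.0)
```

With a 416 input, yolov2 values are bounded by 13 while yolov3 values are bounded by 416, which matches the magnitudes seen in the two cfg examples.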

 
import argparse
import os
import random
import sys
from os.path import join

import numpy as np


def IOU(x, centroids):
    '''
    :param x: (w, h) of a single ground truth box
    :param centroids: (w, h) of the anchors, [(w, h), (), ...], k in total
    :return: the IoU values between the ground truth box and all k anchor boxes
    '''
    IoUs = []
    w, h = x  # w, h of the ground truth box
    for centroid in centroids:
        c_w, c_h = centroid  # w, h of the anchor
        if c_w >= w and c_h >= h:  # anchor encloses the ground truth box
            iou = w * h / (c_w * c_h)
        elif c_w >= w and c_h <= h:  # anchor is wider and shorter
            iou = w * c_h / (w * h + (c_w - w) * c_h)
        elif c_w <= w and c_h >= h:  # anchor is narrower and taller
            iou = c_w * h / (w * h + c_w * (c_h - h))
        else:  # ground truth box encloses the anchor: w > c_w and h > c_h
            iou = (c_w * c_h) / (w * h)
        IoUs.append(iou)  # will become (k,) shape
    return np.array(IoUs)


def avg_IOU(X, centroids):
    '''
    :param X: (w, h) of the ground truth boxes, [(w, h), (), ...]
    :param centroids: (w, h) of the anchors, [(w, h), (), ...], k in total
    :return: mean over all ground truth boxes of the best IoU with any anchor
    '''
    n, d = X.shape
    total = 0.
    for i in range(n):
        # best IoU between this ground truth box and any of the anchors
        total += max(IOU(X[i], centroids))
    return total / n  # average over all ground truth boxes


def write_anchors_to_file(centroids, X, anchor_file, input_shape, yolo_version):
    '''
    :param centroids: (w, h) of the anchors, [(w, h), (), ...], k in total
    :param X: (w, h) of the ground truth boxes, [(w, h), (), ...]
    :param anchor_file: output path for the anchors and the average IoU
    :param input_shape: network input size in pixels (a multiple of 32)
    :param yolo_version: 'yolov2' or 'yolov3'
    '''
    if yolo_version == 'yolov2':
        # yolo downsamples the image by a factor of 32, so the anchors are
        # expressed relative to the feature map of size input_shape/32.
        # If the network architecture changes, use the actual factor instead.
        scale = input_shape / 32.
    elif yolo_version == 'yolov3':
        # yolov3 anchors are expressed in pixels of the original input
        scale = input_shape
    else:
        print("the yolo version is not right!")
        sys.exit(-1)

    anchors = centroids * scale
    print(anchors.shape)

    sorted_indices = np.argsort(anchors[:, 0])  # sort anchors by width
    print('Anchors = ', anchors[sorted_indices])

    with open(anchor_file, 'w') as f:
        # there should not be a comma after the last anchor
        f.write(', '.join('%0.2f,%0.2f' % (anchors[i, 0], anchors[i, 1])
                          for i in sorted_indices))
        f.write('\n')
        f.write('%f\n' % (avg_IOU(X, centroids)))
    print()


def kmeans(X, centroids, eps, anchor_file, input_shape, yolo_version):
    # eps is accepted for compatibility with the callers but is not used:
    # iteration stops once the assignments no longer change
    N = X.shape[0]  # number of ground truth boxes
    print("centroids.shape", centroids.shape)
    k, dim = centroids.shape  # number of anchors k and the two dimensions w, h (dim is 2)
    prev_assignments = np.ones(N) * (-1)  # initial label for every ground truth box
    iteration = 0
    old_D = np.zeros((N, k))  # distance of every ground truth box to every anchor

    while True:
        iteration += 1
        # D.shape = (N, k): distance 1 - IoU of every ground truth box to every anchor
        D = np.array([1 - IOU(X[i], centroids) for i in range(N)])

        # total change in distance compared with the previous iteration
        print("iter {}: dists = {}".format(iteration, np.sum(np.abs(old_D - D))))

        # assign each ground truth box to the anchor with the smallest distance d
        assignments = np.argmin(D, axis=1)

        # if the assignments did not change, write out the anchors and the average IoU
        if (assignments == prev_assignments).all():
            print("Centroids = ", centroids)
            write_anchors_to_file(centroids, X, anchor_file, input_shape, yolo_version)
            return

        # calculate new centroids
        centroid_sums = np.zeros((k, dim), float)  # per-cluster sums of w and h
        for i in range(N):
            centroid_sums[assignments[i]] += X[i]
        for j in range(k):  # mean w, h of each cluster
            count = np.sum(assignments == j)
            if count > 0:  # an empty cluster keeps its previous centroid
                centroids[j] = centroid_sums[j] / count

        prev_assignments = assignments.copy()
        old_D = D.copy()


def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('-filelist', default=r'E:\BaiduNetdiskDownload\darknetHG8245\scripts\train.txt',
                        help='path to filelist\n')
    parser.add_argument('-output_dir', default=r'E:\BaiduNetdiskDownload\darknetHG8245', type=str,
                        help='Output anchor directory\n')
    parser.add_argument('-num_clusters', default=0, type=int,
                        help='number of clusters\n')
    # Note: the yolov2 output values are small because they are relative to the
    # feature map, while the yolov3 output values are large because they are
    # relative to the original image, so the two outputs differ.
    parser.add_argument('-yolo_version', default='yolov2', type=str,
                        help='yolov2 or yolov3\n')
    parser.add_argument('-yolo_input_shape', default=416, type=int,
                        help='input image size, a multiple of 32, e.g. 416*416\n')
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    with open(args.filelist) as f:
        lines = [line.rstrip('\n') for line in f.readlines()]

    annotation_dims = []
    for line in lines:
        # derive the label file path from the image path
        label_path = line.replace('JPEGImages', 'labels')
        label_path = label_path.replace('.jpg', '.txt')
        label_path = label_path.replace('.png', '.txt')
        print(label_path)
        with open(label_path) as f2:
            for label in f2.readlines():
                fields = label.split()
                if len(fields) < 5:  # skip blank or incomplete lines
                    continue
                w, h = fields[3:5]
                annotation_dims.append((float(w), float(h)))
    annotation_dims = np.array(annotation_dims)  # (w, h) of every ground truth box

    eps = 0.005

    if args.num_clusters == 0:
        cluster_counts = range(1, 11)  # we make 1 through 10 clusters
    else:
        cluster_counts = [args.num_clusters]

    for num_clusters in cluster_counts:
        anchor_file = join(args.output_dir, 'anchors%d.txt' % (num_clusters))
        # use randomly chosen ground truth boxes as the initial centroids
        indices = [random.randrange(annotation_dims.shape[0]) for i in range(num_clusters)]
        centroids = annotation_dims[indices]
        kmeans(annotation_dims, centroids, eps, anchor_file, args.yolo_input_shape, args.yolo_version)
        print('centroids.shape', centroids.shape)


if __name__ == "__main__":
    main(sys.argv)
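The script only needs the last two fields of each label line. As a small sketch of that parsing step, the helpers below assume each line has the form "class x y w h" with values normalized to the image size, which is consistent with how the script slices the fields; `parse_label_line` and `box_dims` are illustrative names, not functions from gen_anchors.py.

```python
def parse_label_line(line):
    """Parse one label line 'class x y w h'; return None for blank or malformed lines."""
    fields = line.split()
    if len(fields) != 5:
        return None
    class_id = int(fields[0])
    x, y, w, h = map(float, fields[1:])
    return class_id, (x, y, w, h)


def box_dims(lines):
    """Collect the (w, h) pairs that the clustering uses, ignoring position and class."""
    dims = []
    for line in lines:
        parsed = parse_label_line(line)
        if parsed is not None:
            _, (_, _, w, h) = parsed
            dims.append((w, h))
    return dims


sample = ["0 0.5 0.5 0.25 0.4", "", "3 0.1 0.2 0.05 0.1\n"]
dims = box_dims(sample)
print(dims)  # [(0.25, 0.4), (0.05, 0.1)]
```

Only the width and height reach the clustering because, as described in step ②, anchor boxes have no fixed position.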

[Figure: one of the yolov3 results produced by the script]

 
