2021-11-09
1. Introduction
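As is well known, YOLOv3 obtains suitable anchor box sizes by running k-means clustering on the training data. Plenty of code for this is available online, and it produces results like the image below:
As you can see, the boxes this code produces are unordered, so we have to re-sort them by area and copy them into the cfg file by hand, which is rather tedious. I therefore modified and added some code so that, once anchor clustering finishes, the resulting boxes are written to the cfg file automatically.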
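The manual step described above (sort the boxes by area, then format them for the cfg file) can be sketched in isolation as follows. This is a minimal illustration, not the full script: the helper names `sort_anchors_by_area` and `format_anchors` and the sample boxes are assumptions made for the example.

```python
import numpy as np


def sort_anchors_by_area(boxes):
    """Return the (width, height) boxes ordered by ascending area."""
    boxes = np.asarray(boxes)
    areas = boxes[:, 0] * boxes[:, 1]
    # A stable sort keeps boxes with equal area in their original order.
    order = np.argsort(areas, kind="stable")
    return boxes[order]


def format_anchors(boxes):
    """Format boxes as 'w,h, w,h, ...' for the anchors line of a cfg file."""
    return ", ".join("{},{}".format(int(w), int(h)) for w, h in boxes)


# Hypothetical unordered clustering output, in pixels.
unsorted_boxes = [[116, 90], [10, 13], [33, 23], [16, 30]]
print(format_anchors(sort_anchors_by_area(unsorted_boxes)))
# prints: 10,13, 16,30, 33,23, 116,90
```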
2. Code
# anchors.py
import glob
import xml.etree.ElementTree as ET
import os
import numpy as np
import configparser
from kmeans import kmeans, avg_iou

ANNOTATIONS_PATH = "Annotations"
CLUSTERS = 9
inputsize = 416
cfgname = 'YMAnet.cfg'


def load_dataset(path):
    """Collect the normalized (width, height) of every box in the VOC XML files."""
    dataset = []
    for xml_file in glob.glob("{}/*xml".format(path)):
        tree = ET.parse(xml_file)
        height = int(tree.findtext("./size/height"))
        width = int(tree.findtext("./size/width"))
        for obj in tree.iter("object"):
            # Normalize the corner coordinates by the image size.
            xmin = int(obj.findtext("bndbox/xmin")) / width
            ymin = int(obj.findtext("bndbox/ymin")) / height
            xmax = int(obj.findtext("bndbox/xmax")) / width
            ymax = int(obj.findtext("bndbox/ymax")) / height
            xmin = np.float64(xmin)
            ymin = np.float64(ymin)
            xmax = np.float64(xmax)
            ymax = np.float64(ymax)
            # Report annotation files that contain a zero-width or zero-height box.
            if xmax == xmin or ymax == ymin:
                print(xml_file)
            dataset.append([xmax - xmin, ymax - ymin])
    return np.array(dataset)


if __name__ == '__main__':
    # print(__file__)
    data = load_dataset(ANNOTATIONS_PATH)
    out = kmeans(data, k=CLUSTERS)
    # clusters = [[10,13],[16,30],[33,23],[30,61],[62,45],[59,119],[116,90],[156,198],[373,326]]
    # out= np.array(clusters)/416.0
    # Scale to pixels; use the built-in int because np.int was removed in NumPy 1.24.
    out = np.around(out * inputsize).astype(int)
    # print(out)
    print("Accuracy: {:.2f}%".format(avg_iou(data, out / inputsize) * 100))
    # print("Boxes:\n {}-{}".format(out[:, 0], out[:, 1]))

    # Sort the boxes by ascending area.
    aa = np.multiply(out[:, 0], out[:, 1])
    # print("Area:\n {}".format(aa))
    sort_a = sorted(aa)
    print("Sort_Area:\n {}".format(sort_a))
    aa_sequence = sorted(range(len(aa)), key=lambda k: aa[k])
    # print("aa_sequence:\n {}".format(aa_sequence))
    out = out[aa_sequence, :]

    # Build the "w,h, w,h, ..." string written to the cfg file.
    outformat = ', '.join('{},{}'.format(w, h) for w, h in out)
    print("Boxes:\n {}".format(outformat))
    ratios = np.around(out[:, 0] / out[:, 1], decimals=2).tolist()
    print("Ratios:\n {}".format(sorted(ratios)))

    # Automatically update the anchors in the cfg file.
    # os.getcwd() returns the current working directory, so the cfg folder is
    # looked up next to the directory the script is run from.
    cwd = os.getcwd() + os.sep
    cwd = os.path.abspath(os.path.join(os.path.dirname(cwd), os.path.pardir)) + os.sep
    filePath = cwd + 'cfg' + os.sep + cfgname
    string = 'anchors'

    # Find every line that contains the search string.
    with open(filePath, "r") as f:
        flist = f.readlines()
    c_row = []
    for count, line in enumerate(flist):
        if string in line:
            c_row.append(count)
            # print("Found on line " + str(count))
            # print("Line content: \n" + line)

    # Overwrite each matching line with the sorted anchors and save the file.
    for i in c_row:
        flist[i] = 'anchors = ' + outformat + '\n'
    with open(filePath, 'w') as f:
        f.writelines(flist)
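The script imports `kmeans` and `avg_iou` from a `kmeans.py` module that is not listed in this post. For readers who do not have that file, here is a minimal sketch of what such a module might contain, assuming the common approach of clustering (width, height) pairs with 1 - IoU as the distance. The `seed` and `max_iter` parameters are illustrative additions, and the actual module used with the script may differ.

```python
import numpy as np


def iou(box, clusters):
    """IoU between one (w, h) box and each (w, h) cluster, all aligned at the origin."""
    w = np.minimum(clusters[:, 0], box[0])
    h = np.minimum(clusters[:, 1], box[1])
    intersection = w * h
    box_area = box[0] * box[1]
    cluster_area = clusters[:, 0] * clusters[:, 1]
    return intersection / (box_area + cluster_area - intersection)


def avg_iou(boxes, clusters):
    """Mean, over all boxes, of the IoU with the best-matching cluster."""
    return np.mean([np.max(iou(box, clusters)) for box in boxes])


def kmeans(boxes, k, dist=np.median, seed=0, max_iter=1000):
    """Cluster (w, h) boxes into k anchors using 1 - IoU as the distance."""
    rng = np.random.default_rng(seed)  # seed is an assumption, for repeatability
    n = boxes.shape[0]
    # Start from k distinct boxes picked at random.
    clusters = boxes[rng.choice(n, k, replace=False)].astype(float)
    last = np.full(n, -1)
    for _ in range(max_iter):
        distances = np.array([1 - iou(box, clusters) for box in boxes])
        nearest = np.argmin(distances, axis=1)
        if (nearest == last).all():
            break  # assignments stopped changing
        for c in range(k):
            members = boxes[nearest == c]
            if len(members):  # leave an empty cluster where it is
                clusters[c] = dist(members, axis=0)
        last = nearest
    return clusters


# Tiny synthetic example: two well-separated groups of normalized boxes.
demo_boxes = np.array([[0.1, 0.1]] * 5 + [[0.8, 0.8]] * 5)
demo_clusters = kmeans(demo_boxes, k=2)
print(demo_clusters, avg_iou(demo_boxes, demo_clusters))
```

Using the median rather than the mean to update each cluster makes the anchors less sensitive to a few unusually large or small boxes.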
3. Results
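The output is the sorted Boxes together with their areas, so there is no longer any need to compute the areas and sort them yourself.
These are the anchors in the cfg file before the update.
These are the anchors in the cfg file after the update; as you can see, they have been replaced.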
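One caveat about the update step: the script replaces every line that contains the text "anchors", which would also overwrite a comment that happens to mention the word. A slightly stricter variant, sketched below, only replaces lines whose key (the part before `=`) is exactly `anchors`. The function name `replace_anchor_lines` and the sample cfg lines are hypothetical and only serve the illustration.

```python
def replace_anchor_lines(lines, anchors_text):
    """Return a copy of lines with every 'anchors = ...' setting replaced."""
    updated = []
    for line in lines:
        # The key is whatever precedes the first '=' on the line.
        key = line.split("=", 1)[0].strip()
        if key == "anchors":
            updated.append("anchors = " + anchors_text + "\n")
        else:
            updated.append(line)
    return updated


# Hypothetical excerpt of a cfg file, as returned by readlines().
cfg_lines = [
    "[yolo]\n",
    "mask = 6,7,8\n",
    "anchors = 1,2, 3,4\n",
    "# anchors come from k-means\n",
    "classes=80\n",
]
new_lines = replace_anchor_lines(cfg_lines, "10,13, 16,30")
print("".join(new_lines))
```

Here the comment line is left untouched, while the real `anchors` setting is rewritten.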
4. Where to place the .py file
Note that anchors.py is placed in the data folder, and it modifies the specified cfg file in the cfg folder.
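Because the script builds the cfg path from the current working directory, it only finds the file when it is run from inside the data folder. If you would rather not depend on where the script is launched from, one option is to resolve the path relative to the script file itself. The sketch below assumes the layout described above (a data folder and a cfg folder side by side); the helper name `cfg_path_for` and the example path are hypothetical.

```python
from pathlib import Path


def cfg_path_for(script_path, cfg_name):
    """Resolve <project>/cfg/<cfg_name> for a script stored in <project>/data/."""
    # parent is the data folder, parent.parent is the project root.
    project_root = Path(script_path).resolve().parent.parent
    return project_root / "cfg" / cfg_name


# Inside anchors.py this would be called as cfg_path_for(__file__, cfgname).
example = cfg_path_for("/project/data/anchors.py", "YMAnet.cfg")
print(example)
```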