【学习记录】YOLO的Anchor聚类(自行更新cfg文件)

2021-11-09

一、介绍

我们都知道yolov3针对训练数据通过k-means聚类的方法获得了合适的anchor boxes大小。这类代码网上也比较多,可以得到如以下图片的结果:
在这里插入图片描述
可以看到,这些代码得到的boxes是乱序的,所以需要我们根据这些boxes的面积大小重新排序,并填写入cfg文件中,还是比较麻烦的,于是我修改添加了一些代码,让anchor聚类完后得到boxes能自动更新到cfg文件中,减少一些繁琐的工作。

二、代码

# anchors.py
import glob
import xml.etree.ElementTree as ET
import os
import numpy as np
import configparser
from kmeans import kmeans, avg_iou

ANNOTATIONS_PATH = "Annotations"
CLUSTERS = 9
inputsize = 416
cfgname = 'YMAnet.cfg'

def load_dataset(path):
    dataset = []
    for xml_file in glob.glob("{}/*xml".format(path)):
        tree = ET.parse(xml_file)

        height = int(tree.findtext("./size/height"))
        width = int(tree.findtext("./size/width"))

        for obj in tree.iter("object"):
            xmin = int(obj.findtext("bndbox/xmin")) / width
            ymin = int(obj.findtext("bndbox/ymin")) / height
            xmax = int(obj.findtext("bndbox/xmax")) / width
            ymax = int(obj.findtext("bndbox/ymax")) / height

            xmin = np.float64(xmin)
            ymin = np.float64(ymin)
            xmax = np.float64(xmax)
            ymax = np.float64(ymax)
            if xmax == xmin or ymax == ymin:
                print(xml_file)
            dataset.append([xmax - xmin, ymax - ymin])
    return np.array(dataset)


if __name__ == '__main__':
    # print(__file__)
    data = load_dataset(ANNOTATIONS_PATH)
    out = kmeans(data, k=CLUSTERS)
    # clusters = [[10,13],[16,30],[33,23],[30,61],[62,45],[59,119],[116,90],[156,198],[373,326]]
    # out= np.array(clusters)/416.0
    out = np.around(out * inputsize ).astype(np.int)
    # print(out)
    print("Accuracy: {:.2f}%".format(avg_iou(data, out / inputsize ) * 100))
    # print("Boxes:\n {}-{}".format(out[:, 0], out[:, 1]))
    aa = np.multiply(out[:, 0], out[:, 1])
    # print("Area:\n {}".format(aa))
    sort_a = sorted(aa)
    print("Sort_Area:\n {}".format(sort_a))
    aa_sequence = sorted(range(len(aa)), key=lambda k: aa[k])
    # print("aa_sequence:\n {}".format(aa_sequence))
    out = out[aa_sequence, :]
    outformat = ''
    for i in range(CLUSTERS):
        if i == CLUSTERS-1:
            outformat = outformat + str(out[i, 0]) + ',' + str(out[i, 1])
        else:
            outformat = outformat + str(out[i, 0]) + ',' + str(out[i, 1]) + ',' + '  '


    print("Boxes:\n {}".format(outformat))

    ratios = np.around(out[:, 0] / out[:, 1], decimals=2).tolist()
    print("Ratios:\n {}".format(sorted(ratios)))


    # 自动更新cfg文件中的anchors
    cwd = os.getcwd() + os.sep  # os.getcwd()用于返回当前工作目录
    cwd = os.path.abspath(os.path.join(os.path.dirname(cwd), os.path.pardir)) + os.sep
    filePath = cwd + 'cfg' + os.sep + cfgname
    string = 'anchors'
    # 开始查找
    count = 0
    f = open(filePath, "r")
    flist = f.readlines()
    c_row = []
    for line in flist:
        if string in line:
            c_row += [count]
            # print("第 " + str(count) + " 行已找到.")
            # print("该行内容: \n" + line)
        count += 1
    for i in c_row:
        flist[i] = 'anchors = ' + outformat + '\n'
    f.close()

    f = open(filePath, 'w+')
    f.writelines(flist)
    f.close()

三、结果

输出结果为排序好的Boxes及面积,不再需要自行计算面积大小及排序在这里插入图片描述
这是修改前的cfg文件中anchors的大小
修改前的cfg文件
这是修改后的cfg文件中anchor的大小,可以看到已经更新好了。
在这里插入图片描述

四、py文件放置位置

注意,anchors.py文件是放置在data文件夹中的,它修改的是cfg文件夹中指定的cfg文件。
cfg
在这里插入图片描述
在这里插入图片描述

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值