Using k-means to cluster anchor widths and heights for detection on the VOC2012 dataset

Introduction

One-stage detectors such as SSD and YOLOv3 require the anchor aspect ratios and size ranges to be specified up front. This creates a need to cluster the widths and heights of the ground-truth boxes: by looking at the statistics of the training samples, we can set sensible aspect ratios and size ranges. The code is fairly simple and is listed below; a small warm-up sketch follows this paragraph.
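Before running the full pipeline, a quick look at the width/height statistics already hints at which aspect ratios and scales matter. Here is a minimal sketch of that kind of analysis; the `boxes` array of [width, height] pairs is a hypothetical stand-in for your parsed annotations:

```python
import numpy as np

# boxes: hypothetical (N, 2) array of ground-truth box widths and heights
boxes = np.array([[23, 56], [120, 80], [310, 290], [45, 150], [60, 60]], dtype=float)

ratios = boxes[:, 0] / boxes[:, 1]           # width / height aspect ratios
scales = np.sqrt(boxes[:, 0] * boxes[:, 1])  # geometric-mean box size

print("aspect ratio percentiles (5/50/95):", np.percentile(ratios, [5, 50, 95]))
print("box scale percentiles   (5/50/95):", np.percentile(scales, [5, 50, 95]))
```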

Code

The code has three parts: reading the VOC xml annotations, generating the width/height data, and running the k-means computation. The full listing follows:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from xml.etree import ElementTree
from lxml import etree
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

XML_EXT = '.xml'
ENCODE_METHOD = 'utf-8'

# PascalVocReader reads a VOC xml file and parses it
class PascalVocReader:
    """
    this class will be used to get transfered width and height from voc xml files
    """
    def __init__(self, filepath,width,height):
        # shapes holds [width, height] pairs rescaled to (self.width, self.height)
        self.shapes = []
        self.filepath = filepath
        self.verified = False
        self.width=width
        self.height=height

        try:
            self.parseXML()
        except Exception:
            # skip annotation files that fail to parse
            pass

    def getShapes(self):
        return self.shapes

    def addShape(self, bndbox, width, height):
        # rescale each box from the original image size (width, height)
        # to the target network input size (self.width, self.height)
        xmin = int(bndbox.find('xmin').text)
        ymin = int(bndbox.find('ymin').text)
        xmax = int(bndbox.find('xmax').text)
        ymax = int(bndbox.find('ymax').text)
        width_trans = (xmax - xmin) / width * self.width
        height_trans = (ymax - ymin) / height * self.height
        self.shapes.append([width_trans, height_trans])

    def parseXML(self):
        assert self.filepath.endswith(XML_EXT), "Unsupported file format"
        parser = etree.XMLParser(encoding=ENCODE_METHOD)
        xmltree = ElementTree.parse(self.filepath, parser=parser).getroot()
        pic_size = xmltree.find('size')
        size = (int(pic_size.find('width').text),int(pic_size.find('height').text))
        for object_iter in xmltree.findall('object'):
            bndbox = object_iter.find("bndbox")
            self.addShape(bndbox, *size)
        return True

class create_w_h_txt:
    """Walk a folder of VOC xml files and dump the rescaled (w, h) pairs to a txt file."""
    def __init__(self, vocxml_path, txt_path):
        self.voc_path = vocxml_path
        self.txt_path = txt_path
    def _gather_w_h(self):
        pass
    def _write_to_txt(self):
        pass
    def process_file(self):
        # 'w' overwrites any previous run so stale pairs do not accumulate in the txt file
        file_w = open(self.txt_path, 'w')
        for file in os.listdir(self.voc_path):
            file_path = os.path.join(self.voc_path, file)
            # rescale every box to a 304x304 network input
            xml_parse = PascalVocReader(file_path, 304, 304)
            data = xml_parse.getShapes()
            for w, h in data:
                file_w.write(str(w) + ' ' + str(h) + '\n')
        file_w.close()

class kMean_parse:
    def __init__(self,path_txt):
        self.path = path_txt
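        # 5 anchor clusters with k-means++ seeding; change n_clusters to match your detector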
        self.km = KMeans(n_clusters=5,init="k-means++",n_init=10,max_iter=3000000,tol=1e-3,random_state=0)
        self._load_data()

    def _load_data (self):
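        # the txt file holds one 'width height' pair per line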
        self.data = np.loadtxt(self.path)

    def parse_data (self):
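        # fit k-means and print the cluster centers, i.e. the candidate anchor sizes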
        self.y_k = self.km.fit_predict(self.data)
        print(self.km.cluster_centers_)

    def plot_data (self):
        plt.scatter(self.data[self.y_k == 0, 0], self.data[self.y_k == 0, 1], s=50, c="orange", marker="o", label="cluster 1")
        plt.scatter(self.data[self.y_k == 1, 0], self.data[self.y_k == 1, 1], s=50, c="green", marker="s", label="cluster 2")
        plt.scatter(self.data[self.y_k == 2, 0], self.data[self.y_k == 2, 1], s=50, c="blue", marker="^", label="cluster 3")
        plt.scatter(self.data[self.y_k == 3, 0], self.data[self.y_k == 3, 1], s=50, c="gray", marker="*",label="cluster 4")
        plt.scatter(self.data[self.y_k == 4, 0], self.data[self.y_k == 4, 1], s=50, c="yellow", marker="d",label="cluster 5")
        # draw the cluster centers
        plt.scatter(self.km.cluster_centers_[:, 0], self.km.cluster_centers_[:, 1], s=250, marker="*", c="red", label="cluster center")
        plt.legend()
        plt.grid()
        plt.show()




if __name__ == '__main__':
    whtxt = create_w_h_txt("./voc/Annotations", "./data1.txt")  # VOC annotation folder and output txt path
    whtxt.process_file()
    kmean_parse = kMean_parse("./data1.txt")  # same path as the generated txt file
    kmean_parse.parse_data()
    kmean_parse.plot_data()  # the plotting code only handles five clusters; extend it if you use more

The result looks like the figure below:
[Figure: clustering result]
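To turn the printed cluster centers into an anchor configuration, one common convention is to sort them by area and round to integers. Below is a minimal sketch; `centers_to_anchors` is a hypothetical helper, and the comma-separated YOLO-style string is just one possible output format for the `cluster_centers_` printed by `parse_data()`:

```python
import numpy as np

def centers_to_anchors(centers):
    """Sort cluster centers by area and format them as a YOLO-style anchor string."""
    centers = np.asarray(centers)
    order = np.argsort(centers[:, 0] * centers[:, 1])  # small anchors first
    anchors = np.round(centers[order]).astype(int)
    return ",  ".join("%d,%d" % (w, h) for w, h in anchors)

# example with made-up centers; replace with km.cluster_centers_ from the run above
print(centers_to_anchors([[30.2, 45.7], [120.5, 90.1], [60.3, 200.8], [250.0, 180.4], [15.1, 20.6]]))
```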

Below is Python code that uses the k-means++ algorithm to compute width/height anchor clusters from a VOC dataset:

```python
import numpy as np
import xml.etree.ElementTree as ET
import os

# number of clusters
num_clusters = 9

# load all bounding-box width/height pairs from the VOC dataset
def load_bbox_data(data_path):
    bbox_data = []
    for filename in os.listdir(data_path):
        if filename.endswith('.xml'):
            xml_path = os.path.join(data_path, filename)
            tree = ET.parse(xml_path)
            root = tree.getroot()
            for obj in root.iter('object'):
                bbox = obj.find('bndbox')
                width = int(bbox.find('xmax').text) - int(bbox.find('xmin').text)
                height = int(bbox.find('ymax').text) - int(bbox.find('ymin').text)
                bbox_data.append([width, height])
    return np.array(bbox_data)

# cluster with a k-means++-style (farthest-point) initialization followed by Lloyd iterations
def kmeans_plus_plus(data, k):
    center_ids = [np.random.randint(len(data))]
    while len(center_ids) < k:
        distances = []
        for point in data:
            distance = np.min(np.sum((point - data[center_ids]) ** 2, axis=1))
            distances.append(distance)
        center_ids.append(np.argmax(distances))
    centers = data[center_ids]
    while True:
        clusters = [[] for _ in range(k)]
        for point in data:
            distances = np.sum((point - centers) ** 2, axis=1)
            closest_cluster = np.argmin(distances)
            clusters[closest_cluster].append(point)
        new_centers = np.array([np.mean(cluster, axis=0) for cluster in clusters])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return centers

# sort cluster centers by area (width * height)
def sort_clusters(centers):
    return centers[np.argsort(centers[:, 0] * centers[:, 1])]

# print the resulting anchors
def print_clusters(centers):
    print('Anchors:')
    for anchor in centers:
        print(f' - {anchor[0]:.2f}, {anchor[1]:.2f}')

if __name__ == '__main__':
    data_path = '/path/to/voc/data'
    bbox_data = load_bbox_data(data_path)
    centers = kmeans_plus_plus(bbox_data, num_clusters)
    sorted_centers = sort_clusters(centers)
    print_clusters(sorted_centers)
```

Notes:
1. `load_bbox_data` loads all bounding-box width/height pairs from the VOC dataset.
2. `kmeans_plus_plus` performs the clustering, where `data` is the input data and `k` is the number of clusters.
3. `sort_clusters` sorts the cluster centers by the product of width and height, from smallest to largest.
4. `print_clusters` prints each anchor's width and height in that sorted order.
5. The main block loads the width/height data from the VOC dataset, runs the clustering, then sorts and prints the result.
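The number of clusters (5 in the sklearn version above, 9 here) is a free choice. If you are unsure, an elbow plot of the within-cluster sum of squares over a range of k values can help; here is a minimal sketch using sklearn on the same data1.txt file generated earlier:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

data = np.loadtxt("./data1.txt")  # (w, h) pairs produced by create_w_h_txt

ks = range(2, 13)
inertias = [KMeans(n_clusters=k, init="k-means++", n_init=10, random_state=0).fit(data).inertia_
            for k in ks]

plt.plot(list(ks), inertias, marker="o")
plt.xlabel("number of clusters k")
plt.ylabel("within-cluster sum of squares")
plt.show()
```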