K-D树的C++代码实现

先贴上代码,具体原理后面再逐步完善

kdtree.h:

#ifndef KDTREE_H_
#define KDTREE_H_

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <stack>
#include <utility>
#include <vector>
template<typename T>
class KdTree
{
    // Tree node.
    // - vec:            the k-dimensional point stored at this node.
    // - splitAttribute: index of the axis this node splits on, or -1 for a leaf
    //                   created from a single-point subset (it has no children).
    // - lChild/rChild:  subtrees holding the remaining points whose coordinate on
    //                   splitAttribute is <= / > vec[splitAttribute]. Either may be
    //                   nullptr when that half of the data was empty.
    // - parent:         the node one level up (nullptr for the root).
    struct kdNode
    {
        std::vector<T> vec;
        int splitAttribute;
        kdNode *lChild;
        kdNode *rChild;
        kdNode *parent;

        kdNode(std::vector<T> v = {}, int split = 0, kdNode *lCh = nullptr,
            kdNode *rCh = nullptr, kdNode *par = nullptr) :vec(std::move(v)), splitAttribute(split),
            lChild(lCh), rChild(rCh), parent(par) {}
    };

public:
    // Creates an empty tree (no nodes).
    KdTree() = default;

    // Builds a tree from a list of points. Each inner vector is one point; all
    // points are expected to have the same number of coordinates.
    KdTree(const std::vector<std::vector<T>> &data) : root(createKdTree(data)) {}

    // The tree owns its nodes and releases them when it is destroyed.
    ~KdTree()
    {
        destroy(root);
    }

    // Deep copy, so two trees never share (and double-free) the same nodes.
    KdTree(const KdTree &other) : root(clone(other.root, nullptr)) {}

    // Move: take over the nodes and leave `other` empty.
    KdTree(KdTree &&other) noexcept : root(other.root)
    {
        other.root = nullptr;
    }

    // Copy and move assignment (copy-and-swap). `other` is already a copy (or the
    // moved-from source); its destructor frees the nodes this tree held before.
    KdTree &operator=(KdTree other) noexcept
    {
        std::swap(root, other.root);
        return *this;
    }

    // Returns the root node (nullptr for an empty tree). The tree keeps ownership.
    kdNode *getRoot()
    {
        return root;
    }

    // Returns a copy of the point stored at the root, or an empty vector if the
    // tree is empty.
    std::vector<T> getRootData()
    {
        return root ? root->vec : std::vector<T>{};
    }

    // Transposes a row-major matrix: result[i][j] == data[j][i].
    // The column count is taken from data[0]; every row must have at least that
    // many entries. An empty input yields an empty result.
    std::vector<std::vector<T>> transpose(const std::vector<std::vector<T>> &data)
    {
        if (data.empty())
            return {};

        const size_t rows = data.size();
        const size_t cols = data[0].size();
        std::vector<std::vector<T>> trans(cols, std::vector<T>(rows));
        for (size_t i = 0; i < cols; i++)
        {
            for (size_t j = 0; j < rows; j++)
            {
                trans[i][j] = data[j][i];
            }
        }
        return trans;
    }

    // Population variance (sum of squared deviations divided by n) of the values
    // in vec. Returns 0 for an empty vector.
    double getVariance(const std::vector<T> &vec)
    {
        if (vec.empty())
            return 0.0;

        double sum = 0;
        for (const T &x : vec)
            sum += x;
        const double avg = sum / vec.size();

        double sq = 0;
        for (const T &x : vec)
        {
            const double d = static_cast<double>(x) - avg;
            sq += d * d;
        }
        return sq / vec.size();
    }

    // Given the transposed data (one row per axis), returns the index of the axis
    // with the largest variance; ties keep the lowest index. Returns 0 when there
    // are no axes.
    int getSplitAttribute(const std::vector<std::vector<T>> &data)
    {
        if (data.empty())
            return 0;

        int splitAttribute = 0;
        double maxVar = getVariance(data[0]);
        for (size_t i = 1; i < data.size(); i++)
        {
            const double var = getVariance(data[i]);
            if (var > maxVar)
            {
                splitAttribute = static_cast<int>(i);
                maxVar = var;
            }
        }
        return splitAttribute;
    }

    // Returns the median of vec: the element at index size()/2 in sorted order
    // (the upper median for an even count). vec is reordered in place by
    // std::nth_element, which is O(n) instead of a full sort. Returns T{} for an
    // empty vector.
    T getSplitValue(std::vector<T> &vec)
    {
        if (vec.empty())
            return T{};

        auto mid = vec.begin() + vec.size() / 2;
        std::nth_element(vec.begin(), mid, vec.end());
        return *mid;
    }

    // Euclidean distance between two points. Iterates over v1's coordinates, so
    // v2 must have at least v1.size() entries. Differences are computed in double
    // so unsigned coordinate types cannot wrap around.
    static double getDistance(const std::vector<T> &v1, const std::vector<T> &v2)
    {
        double sum = 0;
        for (size_t i = 0; i < v1.size(); i++)
        {
            const double d = static_cast<double>(v1[i]) - static_cast<double>(v2[i]);
            sum += d * d;
        }
        return std::sqrt(sum);
    }

    // Recursively builds a subtree from data and returns its root (nullptr for
    // empty data). The caller owns the returned nodes.
    // At each level the axis of largest variance is chosen and the data is split
    // at its median. The first point whose coordinate equals the median becomes
    // the node; the remaining points with coordinate <= median go left, > go right.
    kdNode *createKdTree(const std::vector<std::vector<T>> &data)
    {
        if (data.empty())
            return nullptr;

        if (data.size() == 1)
            return new kdNode(data[0], -1);  // single point: leaf without a split axis

        std::vector<std::vector<T>> data_T = transpose(data);
        if (data_T.empty())
            return new kdNode(data[0], -1);  // zero-dimensional points are all identical

        const int splitAttribute = getSplitAttribute(data_T);
        // The median must stay of type T. Storing it in an int truncated fractional
        // values, so no point compared equal to it and splitNode stayed nullptr.
        const T splitValue = getSplitValue(data_T[splitAttribute]);

        std::vector<std::vector<T>> left;
        std::vector<std::vector<T>> right;
        kdNode *splitNode = nullptr;
        for (const auto &point : data)
        {
            if (splitNode == nullptr && point[splitAttribute] == splitValue)
            {
                splitNode = new kdNode(point, splitAttribute);
                continue;
            }
            if (point[splitAttribute] <= splitValue)
                left.push_back(point);
            else
                right.push_back(point);
        }

        try
        {
            splitNode->lChild = createKdTree(left);
            splitNode->rChild = createKdTree(right);
        }
        catch (...)
        {
            destroy(splitNode);  // do not leak the partially built subtree
            throw;
        }
        if (splitNode->lChild)
            splitNode->lChild->parent = splitNode;
        if (splitNode->rChild)
            splitNode->rChild->parent = splitNode;
        return splitNode;
    }

    //----------------------------- Nearest-neighbour search ----------------------------

    // Exact (Euclidean) nearest neighbour of target among the points stored in the
    // subtree rooted at start. Returns a copy of that point, or an empty vector if
    // start is nullptr. target should have the same dimension as the stored points.
    std::vector<T> searchNearestNeighbor(const std::vector<T> &target, kdNode *start)
    {
        const kdNode *best = nullptr;
        double bestDist = std::numeric_limits<double>::infinity();
        searchSubtree(target, start, best, bestDist);
        return best ? best->vec : std::vector<T>{};
    }

    // Nearest neighbour of target over the whole tree (empty vector if the tree is empty).
    std::vector<T> searchNearestNeighbor(const std::vector<T> &target)
    {
        return searchNearestNeighbor(target, root);
    }

    // Prints the subtree rooted at `root` as [left:<left subtree>(point)right<right subtree>].
    // Prints nothing for a null subtree.
    void printTree(const kdNode *root)
    {
        if (root == nullptr)
            return;

        std::cout << "[";
        if (root->lChild)
        {
            std::cout << "left:";
            printTree(root->lChild);
        }

        std::cout << "(";
        for (size_t i = 0; i < root->vec.size(); i++)
        {
            if (i != 0)
                std::cout << ",";
            std::cout << root->vec[i];
        }
        std::cout << ")";

        if (root->rChild)
        {
            std::cout << "right";
            printTree(root->rChild);
        }
        std::cout << "]";
    }

private:
    // Depth-first branch-and-bound search. Updates best/bestDist with any closer
    // point found in the subtree rooted at node.
    static void searchSubtree(const std::vector<T> &target, const kdNode *node,
                              const kdNode *&best, double &bestDist)
    {
        if (node == nullptr)
            return;

        const double dist = getDistance(target, node->vec);
        if (dist < bestDist)
        {
            bestDist = dist;
            best = node;
        }

        const int axis = node->splitAttribute;
        if (axis < 0)
            return;  // leaf: no children

        const double diff = static_cast<double>(target[axis]) - static_cast<double>(node->vec[axis]);
        // Points equal to the split value were sent left at build time, so ties descend left.
        const kdNode *nearSide = diff <= 0 ? node->lChild : node->rChild;
        const kdNode *farSide = diff <= 0 ? node->rChild : node->lChild;

        searchSubtree(target, nearSide, best, bestDist);
        // Every point on the far side is at least |diff| away along `axis`, so that
        // side can only contain a closer point if the hyperplane is within bestDist.
        if (std::abs(diff) <= bestDist)
            searchSubtree(target, farSide, best, bestDist);
    }

    // Deep-copies the subtree rooted at node, attaching the copy to `parent`.
    static kdNode *clone(const kdNode *node, kdNode *parent)
    {
        if (node == nullptr)
            return nullptr;

        kdNode *copy = new kdNode(node->vec, node->splitAttribute, nullptr, nullptr, parent);
        try
        {
            copy->lChild = clone(node->lChild, copy);
            copy->rChild = clone(node->rChild, copy);
        }
        catch (...)
        {
            destroy(copy);
            throw;
        }
        return copy;
    }

    // Frees every node of the subtree rooted at node. Uses an explicit stack so
    // that deep, degenerate trees cannot overflow the call stack.
    static void destroy(kdNode *node)
    {
        std::stack<kdNode *> pending;
        if (node != nullptr)
            pending.push(node);
        while (!pending.empty())
        {
            kdNode *cur = pending.top();
            pending.pop();
            if (cur->lChild)
                pending.push(cur->lChild);
            if (cur->rChild)
                pending.push(cur->rChild);
            delete cur;
        }
    }

    kdNode *root = nullptr;
};

#endif // !KDTREE_H_

main.cpp:

// The header is saved as kdtree.h; the include must match its case to be found
// on case-sensitive file systems.
#include "kdtree.h"
using std::vector;
using std::cout;
using std::endl;

// Demo: builds a kd-tree from six 2-D points, prints the tree and its root, then
// queries the nearest neighbour of (9, 5).
int main()
{
    vector<vector<double>> train = { { 2,3 },{ 5,4 },{ 9,6 },{ 4,7 },{ 8,1 },{ 7,2 } };

    // Stack object: no manual delete is needed (the heap-allocated tree used
    // previously was never freed).
    KdTree<double> tree(train);

    // Prints each coordinate followed by a comma, e.g. "7,2,".
    auto printPoint = [](const vector<double> &point)
    {
        for (double coord : point)
            cout << coord << ",";
    };

    // Print the whole tree.
    tree.printTree(tree.getRoot());
    cout << "\n\n";

    // Print the root point.
    cout << "root=";
    printPoint(tree.getRootData());
    cout << "\n\n";

    // Find the nearest neighbour of the query point.
    vector<double> goal = { 9.0, 5.0 };
    const vector<double> nearestNeighbor = tree.searchNearestNeighbor(goal);
    cout << "\n";
    cout << "(" << goal[0] << "," << goal[1] << ") nearest neighbor is: ";
    printPoint(nearestNeighbor);
    cout << endl;
    return 0;
}

参考:
https://www.cnblogs.com/jingrui/p/10469601.html

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
// A simple C++ example implementation of the k-means clustering algorithm.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// A 2-D point plus the index of the cluster it is currently assigned to.
struct Point {
    double x, y;
    int cluster = -1; // cluster index; -1 until the point has been assigned
};

// Euclidean distance between two points.
double distance(const Point& a, const Point& b) {
    const double dx = a.x - b.x;
    const double dy = a.y - b.y;
    return std::sqrt(dx * dx + dy * dy);
}

// Picks up to k distinct input points, uniformly at random, as the initial
// cluster centres and appends them to `clusters`, numbered 0..count-1.
// Sampling without replacement avoids two centres starting on the same point,
// which would leave one of them without any assigned points. Does nothing if
// `points` is empty or k <= 0. If k exceeds points.size(), only points.size()
// centres are created.
void initClusterCenter(std::vector<Point>& points, std::vector<Point>& clusters, int k) {
    if (points.empty() || k <= 0) {
        return;
    }
    std::vector<std::size_t> order(points.size());
    std::iota(order.begin(), order.end(), std::size_t{0});
    std::random_device rd;
    std::mt19937 gen(rd());
    std::shuffle(order.begin(), order.end(), gen);

    const std::size_t count = std::min(order.size(), static_cast<std::size_t>(k));
    for (std::size_t i = 0; i < count; ++i) {
        Point p = points[order[i]];
        p.cluster = static_cast<int>(i);
        clusters.push_back(p);
    }
}

// Assigns every point to its nearest cluster centre (ties go to the lower index).
// Does nothing when there are no centres.
void assignCluster(std::vector<Point>& points, std::vector<Point>& clusters) {
    if (clusters.empty()) {
        return;
    }
    for (auto& p : points) {
        double minDistance = distance(p, clusters[0]);
        int clusterIndex = 0;
        for (std::size_t i = 1; i < clusters.size(); ++i) {
            const double d = distance(p, clusters[i]);
            if (d < minDistance) {
                minDistance = d;
                clusterIndex = static_cast<int>(i);
            }
        }
        p.cluster = clusterIndex;
    }
}

// Moves each centre to the mean of the points currently assigned to it.
// A centre with no assigned points keeps its previous position instead of
// becoming 0/0 = NaN.
void updateClusterCenter(std::vector<Point>& points, std::vector<Point>& clusters) {
    for (auto& c : clusters) {
        double sumX = 0.0;
        double sumY = 0.0;
        int count = 0;
        for (const auto& p : points) {
            if (p.cluster == c.cluster) {
                sumX += p.x;
                sumY += p.y;
                ++count;
            }
        }
        if (count > 0) {
            c.x = sumX / count;
            c.y = sumY / count;
        }
    }
}

// Returns true when no centre moved by more than epsilon between two iterations.
bool isConverged(std::vector<Point>& oldClusters, std::vector<Point>& newClusters, double epsilon) {
    if (oldClusters.size() != newClusters.size()) {
        return false;
    }
    for (std::size_t i = 0; i < oldClusters.size(); ++i) {
        if (distance(oldClusters[i], newClusters[i]) > epsilon) {
            return false;
        }
    }
    return true;
}

// Runs k-means on `points` until the centres move by at most epsilon or
// maxIterations iterations have run, and returns the final centres.
// On return, each point's `cluster` field is the index of its nearest returned
// centre. Returns no centres if `points` is empty or k <= 0.
std::vector<Point> kMeans(std::vector<Point>& points, int k, double epsilon, int maxIterations) {
    std::vector<Point> clusters;
    initClusterCenter(points, clusters, k);
    if (clusters.empty()) {
        return clusters;
    }
    for (int iter = 1;; ++iter) {
        assignCluster(points, clusters);
        std::vector<Point> newClusters = clusters;
        updateClusterCenter(points, newClusters);
        if (isConverged(clusters, newClusters, epsilon) || iter >= maxIterations) {
            // Re-label against the centres actually returned so that the point
            // assignments and the returned centres agree.
            assignCluster(points, newClusters);
            return newClusters;
        }
        clusters = std::move(newClusters);
    }
}

int main() {
    // Generate some random points.
    std::vector<Point> points;
    points.reserve(100);
    for (int i = 0; i < 100; ++i) {
        Point p;
        p.x = std::rand() % 100;
        p.y = std::rand() % 100;
        points.push_back(p);
    }

    // Run k-means clustering.
    std::vector<Point> clusters = kMeans(points, 3, 0.01, 100);

    // Print the points in each cluster.
    for (const auto& c : clusters) {
        std::cout << "Cluster " << c.cluster << ":\n";
        for (const auto& p : points) {
            if (p.cluster == c.cluster) {
                std::cout << "(" << p.x << "," << p.y << ")\n";
            }
        }
    }
    return 0;
}

// This is a simple example; more complex applications may need further
// optimization and tuning.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值