K-d(k-dimension)Tree原理及实现(Java)

  使用BST存储2D数据有天然的缺陷,只能比较一个维度的信息,从而造成信息丢失。通常使用四叉树(QuadTree,按照空间划分区域)存储2D数据,使用八叉树(OctTree,按照空间划分区域)存储3D数据。
  K-d(k-dimension)Tree是以二叉树的形式存储数据的,但它使用特殊的方式划分空间,理论上可以存储多个维度的数据,并且在最近邻搜索算法nearest()上通过剪枝,获得了较高的效率。
  下面以2D kdTree为例,多维kdTree在此基础上扩展即可。
满足条件
  root节点将整个空间分为左和右(即在这一层按照x坐标来比较);
  下一层的节点按照与上层节点相反的方式划分空间。


Operation

	public void addNode(Node newNode); // 向kdTree添加新的节点
	public Node nearest(double x, double y); // 返回kdTree中距离坐标A(x, y)最近的节点

Node结构

	private int value;
    private double x;  // 节点相对于root节点的坐标
    private double y;
    private boolean direction; // 此节点划分空间的方向, true:左右(x轴),false:上下(y轴)
    private Node left;   // 左子树和右子树
    private Node right;

Nearest(int x, int y)思路
  1.先从当前节点(最初是root节点)开始,计算它和A(x, y)的距离,将其和已知的最近节点进行比较;
  2.找到当前节点左右子树中较好的一侧goodSide(A所在区域),较差的一侧badSide;
  3.在goodSide中进行search,更新最近节点;
  4.如果badSide中任何节点的距离都比最近节点大,则略过它,否则search badSide;
  注释:
    步骤2之所以区分goodSide和badSide是因为从理论上来说,goodSide的一侧大概率存在最近节点;步骤3一定要先于步骤4完成,只有更新了goodSide中的最近节点后,才会对badSide进行有效的剪枝。


java代码
  使用如下所示的二维空间节点为例
在这里插入图片描述
  其kdTree的结构如下所示
在这里插入图片描述

  KdTree.java

public class KdTree {
    /* 2维kdTree
     */
    private Node root;

    KdTree(int value) {
        root = new Node(value, 0, 0, true); // root节点初始化
    }

    public void addNode(Node newNode) {
        /* 向kdTree添加新的节点
         */
        Node curNode = root;
        addNodeHelper(root, newNode);
    }

    public Node nearest(double x, double y) {
        /* 先从当前节点(最初是root节点)开始,计算它和A(x, y)的距离,将其和已知的最近节点进行比较
         * 找到当前节点左右子树中较好的一侧goodSide(A所在区域),较差的一侧badSide;
         * 在goodSide中进行search,更新最近节点;
         * 如果badSide中任何节点的距离都比最近节点大,则略过它,否则search badSide;
         */
        return nearestHelper(root, x, y, root);
    }

    private Node nearestHelper(Node curNode, double x, double y, Node nearNode) {
        if (distance(curNode, x, y) < distance(nearNode, x, y)) {
            nearNode = curNode;
        }

        Node[] sides = getSides(curNode, x, y, nearNode);
        Node goodSide = getGoodSide(sides);
        Node badSide = getBadSide(sides);

        if (goodSide != null) {
            nearNode = nearestHelper(goodSide, x, y, nearNode);
        }

        if (badSide != null) {
            if (curNode.isDirection() && Math.abs(badSide.getY() - y) < distance(nearNode, x, y)) {
                nearNode = nearestHelper(badSide, x, y, nearNode);
            } else if (!curNode.isDirection() && Math.abs(badSide.getX() - x) < distance(nearNode, x, y)) {
                nearNode = nearestHelper(badSide, x, y, nearNode);
            }
        }

        return nearNode;
    }

    private Node[] getSides(Node curNode, double x, double y, Node nearNode) {
        Node[] res = new Node[2];
        Node left = curNode.getLeft();
        Node right = curNode.getRight();
        if (curNode.isDirection()) {
            if (curNode.getX() > x) {
                getRes(left, right, res);
            } else {
                getRes(right, left, res);
            }
        } else {
            if (curNode.getY() > y) {
                getRes(right, left, res);
            } else {
                getRes(left, right, res);
            }
        }
        return res;
    }

    private void getRes(Node a, Node b, Node[] res) {
        res[0] = a;
        res[1] = b;
    }

    private Node getGoodSide(Node[] sides) {
        return sides[0];
    }

    private Node getBadSide(Node[] sides) {
        return sides[1];
    }

    private double distance(Node node, double x, double y) {
        double lenX = node.getX() - x;
        double lenY = node.getY() - y;
        return Math.sqrt(Math.pow(lenX, 2) + Math.pow(lenY, 2));
    }

    private void addNodeHelper(Node curNode, Node newNode) {
        /* 根据当前节点划分的区域添加newNode
         */
        if (curNode.isDirection()) {
            leftOrRight(curNode.getX(), newNode.getX(), curNode, newNode);
        } else {
            leftOrRight(newNode.getY(), curNode.getY(), curNode, newNode);
        }
    }

    private void leftOrRight(double a, double b, Node cur, Node newNode) {
        /* 判断newnode应该添加到cur节点的左子树还是右子树
         */
        if (a > b) {
            leftSide(cur, newNode);
        } else {
            rightSide(cur, newNode);
        }
    }

    private void leftSide(Node cur, Node newNode) {
        /* newNode应该添加到cur节点的左子树上
         * 如果cur节点左子树为空,直接setLeft为newNode
         * 否则,调用helper方法进行添加
         */
        if (checkLeftEmpty(cur)) {
            cur.setLeft(newNode);
            newNode.setDirection(cur.isDirection());
        } else {
            addNodeHelper(cur.getLeft(), newNode);
        }
    }

    private boolean checkLeftEmpty(Node node) {
        return node.getLeft() == null;
    }

    private void rightSide(Node cur, Node newNode) {
        /* newNode应该添加到cur节点的右子树上
         * 如果cur节点右子树为空,直接setRight为newNode
         * 否则,调用helper方法进行添加
         */
        if (checkRightEmpty(cur)) {
            cur.setRight(newNode);
            newNode.setDirection(cur.isDirection());
        } else {
            addNodeHelper(cur.getRight(), newNode);
        }
    }

    private boolean checkRightEmpty(Node node) {
        return node.getRight() == null;
    }
}

  Node.java

public class Node {
    private int value;
    private double x;  // 节点相对于root节点的坐标
    private double y;
    private boolean direction; // 此节点划分空间的方向, true:左右(x轴),false:上下(y轴)
    private Node left;   // 左子树和右子树
    private Node right;

    Node(int v, double x, double y, boolean dir) {
        this.value = v;
        this.x = x;
        this.y = y;
        this.direction = dir;
        this.left = null;
        this.right = null;
    }

    public int getValue() {
        return value;
    }

    public double getX() {
        return x;
    }

    public double getY() {
        return y;
    }

    public boolean isDirection() {
        return direction;
    }

    public Node getLeft() {
        return left;
    }

    public Node getRight() {
        return right;
    }

    public void setLeft(Node left) {
        this.left = left;
    }

    public void setRight(Node right) {
        this.right = right;
    }

    public void setDirection(boolean parentDir) {
        this.direction = !parentDir;
    }
}



To be a sailor of the world bound for all ports.
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
好的,下面是一个简单的Python实现k-means算法的案例: ```python import numpy as np import matplotlib.pyplot as plt def kMeans(dataSet, k): m, n = dataSet.shape centroids = np.zeros((k, n)) for i in range(k): index = int(np.random.uniform(0, m)) centroids[i, :] = dataSet[index, :] clusterChanged = True while clusterChanged: clusterChanged = False clusterAssment = np.zeros((m, 2)) for i in range(m): minDist = np.inf minIndex = -1 for j in range(k): dist = np.sqrt(np.sum(np.power(dataSet[i, :] - centroids[j, :], 2))) if dist < minDist: minDist = dist minIndex = j if clusterAssment[i, 0] != minIndex: clusterChanged = True clusterAssment[i, :] = minIndex, minDist ** 2 for j in range(k): pointsInCluster = dataSet[np.nonzero(clusterAssment[:, 0] == j)] centroids[j, :] = np.mean(pointsInCluster, axis=0) return centroids, clusterAssment def show(dataSet, k, centroids, clusterAssment): m, n = dataSet.shape if n != 2: print("Dimension of dataSet should be 2!") return colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w'] for i in range(k): pointsInCluster = dataSet[np.nonzero(clusterAssment[:, 0] == i)] plt.scatter(pointsInCluster[:, 0], pointsInCluster[:, 1], marker='o', c=colors[i % len(colors)], alpha=0.5) plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', c='k', s=100, linewidths=3) plt.title('k-means') plt.xlabel('X') plt.ylabel('Y') plt.show() if __name__ == '__main__': data = np.random.rand(200, 2) k = 3 centroids, clusterAssment = kMeans(data, k) show(data, k, centroids, clusterAssment) ``` 该案例中,我们首先随机生成一组数据,然后输入数据和k值,即可得到聚类结果。其中,函数`kMeans()`实现了k-means算法,函数`show()`用于绘制聚类结果的图表。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

carpe~diem

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值