Kd-Tree 普林斯顿 算法第四版

系列文章目录



前言

Kd-Tree这个是网课上第五周的课后作业,主要考察对于Kd-Tree的理解与应用。

详细的要求可以参照作业的链接:链接

如果遇到不会的可以参考答疑链接,这里包括了常见的问题以及老师推荐的做法


一、PointSET 分析与代码

PointSET需要建立基本题目要求的数据结构和API,该部分较简单,跟着要求完成即可

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.SET;
import edu.princeton.cs.algs4.StdDraw;

public class PointSET {
    private final SET<Point2D> Points;

    public PointSET() {
        Points = new SET<Point2D>();
    }

    public boolean isEmpty() {
        return Points.isEmpty();
    }

    public int size() {
        return Points.size();
    }

    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        if (!Points.contains(p)) {
            Points.add(p);
        }
    }

    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        return Points.contains(p);
    }

    public void draw() {
        for (Point2D p : Points) {
            p.draw();
        }
        StdDraw.show();
    }

    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }
        SET<Point2D> inside_Points = new SET<Point2D>();
        for (Point2D p : Points) {
            if (rect.contains(p)) {
                inside_Points.add(p);
            }
        }
        return inside_Points;
    }

    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        Point2D point = null;
        double min_dis = Double.MAX_VALUE;
        for (Point2D p1 : Points) {
            double dis = p.distanceSquaredTo(p1);
            if (dis < min_dis) {
                point = p1;
                min_dis = dis;
            }
        }
        return point;
    }

    public static void main(String[] args) {

    }
}

二、KdTree

1.编写顺序

参考链接中的推荐编写顺序:
Start by writing isEmpty() and size(). These should be very easy. From there, write a simplified version of insert() which does everything except set up the RectHV for each node. Write the contains() method, and use this to test that insert() was implemented properly. Note that insert() and contains() are best implemented by using private helper methods analogous to those found on page 399 of the book or by looking at BST.java. We recommend using orientation as an argument to these helper methods.
Now add the code to insert() which sets up the RectHV for each Node. Next, write draw(), and use this to test these rectangles. Finish up KdTree with the nearest and range methods. Finally, test your implementation using our interactive client programs as well as any other tests you’d like to conduct.
即:
1.先完成isEmpty,size
2.随后完成insert(不设置RectHV)
3.完成contians
4.insert中设置RectHV
5.完成Draw
6.完成其他(range、nearest)

其中1-5的实现相对容易

2.range、nearst分析(翻译)

范围搜索。要查找给定查询矩形中包含的所有点,请从根开始,使用以下修剪规则在两个子树中递归搜索点:如果查询矩形与对应于节点的矩形不相交,则无需搜索该节点(或其子树)。仅当子树可能包含查询矩形中包含的点时,才会搜索该子树。

最近邻搜索。要查找与给定查询点最近的点,请从根开始,使用以下修剪规则在两个子树中递归搜索:如果迄今为止发现的最近点距离查询点和对应于节点的矩形之间的距离较近,则无需搜索该节点(或其子树)。也就是说,仅当某个节点可能包含比目前为止找到的最佳节点更接近的点时,才搜索该节点。修剪规则的有效性取决于快速找到附近的点。为此,请组织递归方法,以便在有两个可能的子树向下时,始终选择位于分割线同一侧的子树作为查询点,作为第一个子树来搜索找到的最近点,而搜索第一个子树可能会对第二个子树进行修剪。

3.实现代码

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.SET;
import edu.princeton.cs.algs4.StdDraw;

public class KdTree {
    private Node root;
    private int size;

    private static class Node {
        private final Point2D p;      // the point
        private final RectHV rect;    // the axis-aligned rectangle corresponding to this node
        private Node lb;        // the left/bottom subtree
        private Node rt;        // the right/top subtree
        //private int N;          // count
        private final int depth;
        //private boolean vert;

        public Node(Point2D p, int depth, RectHV rect) {
            if (p == null)
                throw new NullPointerException();
            this.p = p;
            //this.N = N;
            this.rect = rect;
            this.depth = depth;
        }
    }

    public boolean isEmpty() {
        return size() == 0;
    }

    public int size() {
        return size;
    }

    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        } else if (contains(p))
            return;
        else {
            RectHV Rect = new RectHV(0, 0, 1, 1);
            root = insert(root, p, 0, Rect);
            size++;
        }

    }

    private Node insert(Node x, Point2D p, int depth, RectHV rect) {
        if (x == null) {
            return new Node(p, depth, rect);//其他节点
        } else {
            //left,right
            if (depth % 2 == 0) {
                if (p.x() < x.p.x()) {
                    RectHV leftRect = new RectHV(rect.xmin(), rect.ymin(), x.p.x(), rect.ymax());
                    x.lb = insert(x.lb, p, depth + 1, leftRect);
                } else {
                    RectHV rightRect = new RectHV(x.p.x(), rect.ymin(), rect.xmax(), rect.ymax());
                    x.rt = insert(x.rt, p, depth + 1, rightRect);
                }
            }
            //top,bottom
            else {
                if (p.y() < x.p.y()) {
                    RectHV bottomRect = new RectHV(rect.xmin(), rect.ymin(), rect.xmax(), x.p.y());
                    x.lb = insert(x.lb, p, depth + 1, bottomRect);
                } else {
                    RectHV topRect = new RectHV(rect.xmin(), x.p.y(), rect.xmax(), rect.ymax());
                    x.rt = insert(x.rt, p, depth + 1, topRect);
                }
            }
        }
        return x;
    }

    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        } else {
            if (root == null) {
                return false;
            } else {
                // 递归的写法
                return contains(p, root);
            }
        }
    }

    private int compare(Point2D p, Node n) {
        if (n.depth % 2 == 0) {
            // 如果是偶数层,按 x 比较
            if (Double.compare(p.x(), n.p.x()) == 0) {
                return Double.compare(p.y(), n.p.y());
            } else {
                return Double.compare(p.x(), n.p.x());
            }
        } else {
            // 按 y 比较
            if (Double.compare(p.y(), n.p.y()) == 0) {
                return Double.compare(p.x(), n.p.x());
            } else {
                return Double.compare(p.y(), n.p.y());
            }
        }
    }

    private boolean contains(Point2D p, Node node) {
        if (node == null) {
            return false;
        } else if (p.equals(node.p)) {
            return true;
        } else {
            if (compare(p, node) < 0) {
                return contains(p, node.lb);
            } else {
                return contains(p, node.rt);
            }
        }
    }

    public void draw() {
        for (Node node : Get_Nodes()) {


            if (node.depth % 2 == 0) {
                //vertical
                StdDraw.setPenRadius(0.005);
                StdDraw.setPenColor(StdDraw.RED);
                StdDraw.line(node.p.x(), node.rect.ymin(), node.p.x(), node.rect.ymax());
            } else {
                //h
                StdDraw.setPenRadius(0.005);
                StdDraw.setPenColor(StdDraw.BLUE);
                StdDraw.line(node.rect.xmin(), node.p.y(), node.rect.xmax(), node.p.y());
            }
            StdDraw.setPenColor(StdDraw.BLACK);
            StdDraw.setPenRadius(0.01);
            node.p.draw();
        }
        StdDraw.show();
    }

    private Iterable<Node> Get_Nodes() {
        Queue<Node> kNodes = new Queue<Node>();
        order(root, kNodes);
        return kNodes;
    }

    private void order(Node node, Queue<Node> q) {
        if (node == null) return;
        q.enqueue(node);
        order(node.lb, q);
        order(node.rt, q);
    }

    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }
        SET<Point2D> inside_Nodes = new SET<Point2D>();

        range(root, inside_Nodes, rect);
        return inside_Nodes;
    }

    private void range(Node node, SET<Point2D> points, RectHV rect) {
        if (rect.contains(node.p)) {
            points.add(node.p);
        }
        if (node.lb != null && node.rect.intersects(rect)) {
            range(node.lb, points, rect);  //左子树
        }

        if (node.rt != null && node.rect.intersects(rect)) {
            range(node.rt, points, rect);  //右子树
        }

    }

    public Point2D nearest(Point2D p) {
        if (p == null)
            throw new NullPointerException();
        if (root != null)
            return nearPoint(root, p, root).p;
        return null;
    }

    private Node nearPoint(Node kd, Point2D query, Node target) {
        if (kd == null) return target;
        double nrDist = query.distanceSquaredTo(target.p);//last find target distance
        double kdDist = query.distanceSquaredTo(kd.p);

        if (nrDist >= kdDist || nrDist >= kd.rect.distanceSquaredTo(query)) {

            if (nrDist > kdDist) target = kd;

            if (kd.depth % 2 == 0) {
                double cmpX = query.x() - kd.p.x();
                if (cmpX < 0.0) {
                    //左侧
                    if (kd.lb != null) target = nearPoint(kd.lb, query, target);
                    if (kd.rt != null) target = nearPoint(kd.rt, query, target);
                } else {
                    //右侧
                    if (kd.rt != null) target = nearPoint(kd.rt, query, target);
                    if (kd.lb != null) target = nearPoint(kd.lb, query, target);
                }
            } else {
                double cmpY = query.y() - kd.p.y();
                if (cmpY < 0.0) {
                    //下侧
                    if (kd.lb != null) target = nearPoint(kd.lb, query, target);
                    if (kd.rt != null) target = nearPoint(kd.rt, query, target);
                } else {
                    //上侧
                    if (kd.rt != null) target = nearPoint(kd.rt, query, target);
                    if (kd.lb != null) target = nearPoint(kd.lb, query, target);
                }
            }
        }
        return target;
    }

    public static void main(String[] args) {

    }
}

总结

本代码参考了的nearst实现参考了链接
本代码最终评分为81/100,刚好通过。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值