Kd-Trees
PointSET.java
暴力法没啥好说的,红黑树都用的现成的,照着API写就完事儿了,注意异常的抛出。
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
public class PointSET {
private final Set<Point2D> points;
public PointSET() {
points = new TreeSet<>();
}
public boolean isEmpty() {
return points.isEmpty();
}
public int size() {
return points.size();
}
public void insert(Point2D p) {
if (p == null) {
throw new IllegalArgumentException();
}
points.add(p);
}
public boolean contains(Point2D p) {
if (p == null) {
throw new IllegalArgumentException();
}
return points.contains(p);
}
public void draw() {
for (Point2D point : points) {
point.draw();
}
}
public Iterable<Point2D> range(RectHV rect) {
if (rect == null) {
throw new IllegalArgumentException();
}
List<Point2D> res = new ArrayList<>();
for (Point2D p : points) {
if (rect.contains(p)) {
res.add(p);
}
}
return res;
}
public Point2D nearest(Point2D p) {
if (p == null) {
throw new IllegalArgumentException();
}
if (isEmpty()) {
return null;
}
double minDis = Double.POSITIVE_INFINITY;
Point2D min = null;
for (Point2D point : points) {
double dis = p.distanceSquaredTo(point);
if (dis < minDis) {
minDis = dis;
min = point;
}
}
return min;
}
}
KdTree.java
range()的寻找策略:如果根据当前结点划分出来的左侧(即左/下平面)矩形区域与查询区域有交集,说明左侧区域中可能有点落在查询区域中,需要向左子树进行递归;同理,判断是否需要向右子树递归。
nearest()的剪枝策略:每一次递归,都先向离查询点更接近的左/右区域进行递归,并更新最近邻点,此时判断查询点到剩余的右/左区域的距离,是否小于查询点到更新后的最近邻点的距离,如果是,说明在剩余区域中仍可能有点到查询点的距离更近,需要向右/左区域进行递归;如果不是,说明剩余区域中所有的点到查询点的距离都大于当前已得到的最短距离,无需再向右/左区域递归,即将这一枝剪去。
(点到矩形的距离,指的是该点到矩形四边上的最短距离,按两者的相对位置关系会有两种情况:1. 点到最近边的垂直距离;2. 点到矩形最近顶点的距离)
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;
import java.util.ArrayList;
import java.util.List;
public class KdTree {
private static final boolean VER = true;
private static final boolean HOR = false;
private Node root;
private int size;
public KdTree() {}
public boolean isEmpty() {
return size == 0;
}
public int size() {
return size;
}
public void insert(Point2D p) {
if (p == null) {
throw new IllegalArgumentException();
}
root = insert(root, VER, p, 0, 0, 1, 1);
}
private Node insert(Node x, boolean direction, Point2D p, double xmin, double ymin, double xmax, double ymax) {
if (x == null) {
size++;
return new Node(p, new RectHV(xmin, ymin, xmax, ymax));
}
// 特殊情况,当已经存在相同的点时不插入
if (p.equals(x.point)) {
return x;
}
// 每次递归都要保存对应的矩形区域
if (direction == VER) {
if (p.x() < x.point.x()) {
x.left = insert(x.left, HOR, p, xmin, ymin, x.point.x(), ymax);
} else {
x.right = insert(x.right, HOR, p, x.point.x(), ymin, xmax, ymax);
}
} else {
if (p.y() < x.point.y()) {
x.left = insert(x.left, VER, p, xmin, ymin, xmax, x.point.y());
} else {
x.right = insert(x.right, VER, p, xmin, x.point.y(), xmax, ymax);
}
}
return x;
}
public boolean contains(Point2D p) {
if (p == null) {
throw new IllegalArgumentException();
}
Node x = root;
boolean direction = VER;
while (x != null) {
if (p.equals(x.point)) {
return true;
}
if ((direction == VER && p.x() < x.point.x())
|| (direction == HOR && p.y() < x.point.y())) {
x = x.left;
} else {
x = x.right;
}
direction = !direction;
}
return false;
}
public void draw() {
draw(root, VER);
}
// 注意画点和线时颜色、粗细的差别
private void draw(Node x, boolean direction) {
if (x == null) {
return;
}
// 画点
StdDraw.setPenRadius(0.01);
StdDraw.setPenColor(StdDraw.BLACK);
x.point.draw();
// 画线
StdDraw.setPenRadius();
if (direction == VER) {
StdDraw.setPenColor(StdDraw.RED);
StdDraw.line(x.point.x(), x.rect.ymin(), x.point.x(), x.rect.ymax());
} else {
StdDraw.setPenColor(StdDraw.BLUE);
StdDraw.line(x.rect.xmin(), x.point.y(), x.rect.xmax(), x.point.y());
}
draw(x.left, !direction);
draw(x.right, !direction);
}
public Iterable<Point2D> range(RectHV rect) {
if (rect == null) {
throw new IllegalArgumentException();
}
List<Point2D> list = new ArrayList<>();
// 只有当root非空时才调用,因为调用函数递归时没有对null的判断
if (!isEmpty()) {
range(root, rect, list);
}
return list;
}
private void range(Node x, RectHV rect, List<Point2D> list) {
if (rect.contains(x.point)) {
list.add(x.point);
}
// 只有当左右子树对应的矩形区域与查询区域有交集时,才可能有点落在查询区域里
if (x.left != null && x.left.rect.intersects(rect)) {
range(x.left, rect, list);
}
if (x.right != null && x.right.rect.intersects(rect)) {
range(x.right, rect, list);
}
}
public Point2D nearest(Point2D p) {
if (p == null) {
throw new IllegalArgumentException();
}
// 特殊情况要排除
if (isEmpty()) {
return null;
}
return nearest(root, VER, p, root.point);
}
// 每次调用都返回更新后的最近邻点
private Point2D nearest(Node x, boolean direction, Point2D p, Point2D neighbor) {
if (x == null) {
return neighbor;
}
if (p.equals(x.point)) {
return x.point;
}
// 先对当前点进行判断,并更新最近邻点
if (p.distanceSquaredTo(x.point) < p.distanceSquaredTo(neighbor)) {
neighbor = x.point;
}
// 每次都先向更靠近查询点的子树区域进行递归
// 只有当查询点到剩余子树的矩形区域的距离小于到更新后的最近邻点的距离时,才向剩余子树递归
if ((direction == VER && p.x() < x.point.x())
|| (direction == HOR && p.y() < x.point.y())) {
neighbor = nearest(x.left, !direction, p, neighbor);
if (x.right != null
&& x.right.rect.distanceSquaredTo(p) < p.distanceSquaredTo(neighbor)) {
neighbor = nearest(x.right, !direction, p, neighbor);
}
} else {
neighbor = nearest(x.right, !direction, p, neighbor);
if (x.left != null
&& x.left.rect.distanceSquaredTo(p) < p.distanceSquaredTo(neighbor)) {
neighbor = nearest(x.left, !direction, p, neighbor);
}
}
return neighbor;
}
private class Node {
Point2D point;
RectHV rect; // 以当前点为根的子树所对应的总矩形区域
Node left, right;
Node(Point2D point, RectHV rect) {
this.point = point;
this.rect = rect;
}
}
}