这是该系列的第三篇。一维、二维的最邻近点对问题(Closest-Pair Problem):戳这里
PS:建议先快速浏览低维的分治解法,因为各维度的解决思路具有高度的关联性。如果还没有弄清二维解法仍要执意往下看,最好先给头发买好保险。
1 问题描述
对于三维点 p i = ( x i 1 , x i 2 , x i 3 ) p_i = (x_{i1}, x_{i2}, x_{i3}) pi=(xi1,xi2,xi3),任意两点的欧几里得距离为 ∑ k = 1 3 ( x i k − x j k ) 2 \sqrt{\sum_{k=1}^3(x_{ik} - x_{jk})^2} ∑k=13(xik−xjk)2。
暴力算法与二维完全一致,枚举所有可能的点对,计算距离并选择最小距离的点对返回。它的时间复杂度是 O ( n 2 ) O(n^2) O(n2)。
同样,借助分治思想可以把复杂度降到 O ( n l o g n ) O(nlogn) O(nlogn)。
2 算法描述
2.1 开始前
首先对于一个乱序的点集 P P P,我们需要分别对 x x x和 z z z进行排序,得到 P x P_x Px和 P z P_z Pz。前者用于divide,后者用于merge。不对 y y y进行排序的原因是,我们将通过切片来减少这个维度上的重复工作。
2.2 Divide步骤
这一步与二维情况完全一致,只不过按 y y y的情况变为了按 z z z。如理解二维的Divide,可跳过。
第一步:拆分
对于任意的点集 P P P,其中点的数量 ∣ P ∣ = n \vert P\vert = n ∣P∣=n,类似于二维情况,记 ⌊ n / 2 ⌋ \lfloor n/2\rfloor ⌊n/2⌋为 m i mi mi,可以在 P x P_x Px中第 m i mi mi个点的坐标(记为 x ∗ x^* x∗)处,将其分为 Q Q Q区与 R R R区。其中坐标小于等于 x ∗ x^* x∗的点放入 Q Q Q区,大于的点放入 R R R区。
准确来说,分区的标准是点在 P x P_x Px中的位置,小于等于 m i mi mi划入 Q Q Q区,反之 R R R区,不过按 x ∗ x^* x∗来理解会比较方便。
第二步: 维护有序数组
对于 Q Q Q区和 R R R区,我们想要维护4个有序数组:按 x x x或 z z z排好序的 Q x Q_x Qx, Q z Q_z Qz, R x R_x Rx, R z R_z Rz。其中,按 x x x排序的数组, Q x Q_x Qx和 R x R_x Rx,按索引就可以分开了。
而 Q z Q_z Qz和 R z R_z Rz,需要新建数组,顺次遍历 P z P_z Pz并判断从属的区域来构建,花费 O ( n ) O(n) O(n)时间,且它们自然就是有序的。注意,如果你想通过 Q x Q_x Qx和 R x R_x Rx重排序得到,会影响整个算法的复杂度。
现在原问题就变成了两个完全相同的子问题。最终只有小于3个点时,可以直接比较每个点对得出答案。
2.3 Merge步骤
思路与二维是完全一致的。
假设Q区找到了局部最邻近点 { q 1 , q 2 } \{q_1, q_2\} {q1,q2},距离为 δ 1 \delta_1 δ1,同理R区找到了 { r 1 , r 2 } \{r_1, r_2\} {r1,r2},距离为 δ 2 \delta_2 δ2。并且在所有被分割线割离的点对中,即所有点对 { p i , p j } \{p_{i}, p_{j}\} {pi,pj}且 p i ∈ Q ∧ p j ∈ R p_i \in Q \wedge p_j \in R pi∈Q∧pj∈R,有一最小距离 δ 3 \delta_3 δ3。现在对于原问题的解,有完全相同于一维的3种情况:
- 最小距离的点对在被分割线割离的点对集合中,即 δ 3 \delta_3 δ3是最小,应返回对应被割离的点对。
- Q Q Q区中找到的点对距离最小,应返回 { q 1 , q 2 } \{q_1, q_2\} {q1,q2}。
- R R R区中找到的点对距离最小,应返回 { r 1 , r 2 } \{r_1, r_2\} {r1,r2}。
算法的关键点仍然是找出 δ 3 \delta_3 δ3。温馨提示:实现下面的关键算法时,请备好一杯茶或者是一包烟。
2.3.1 构建两个 S S S区
令 δ = m i n ( δ 1 , δ 2 ) \delta = min(\delta_1, \delta_2) δ=min(δ1,δ2),把 x ∗ x^* x∗左右宽为 δ \delta δ的区域记为 S S S区。其中,左半部分是 S 1 S_1 S1区,有半部分是 S 2 S_2 S2区。我们只需要检查 S S S内左右半区的点的组合即可。那么,是否可以利用 P z P_z Pz像二维那样只用检查一部分点对呢?
或许你已经猜到了,多增加的一个维度让二维下的这种算法完全失效。考虑 y y y方向这个维度,仍有可能需要检查所有点对。我们必须对y再次进行分割。
2.3.2 对 y y y进一步分区
如下图所示,我们统计 y y y方向上覆盖的范围,并错位分区(大的红色分区,看图应该就很好理解了吧),然后编号。很容易发现,处于 S 1 [ i ] S_1[i] S1[i]内的一个点,只需要与 S 2 [ i − 1 ] S_2[i-1] S2[i−1]与 S 2 [ i ] S_2[i] S2[i]两个区域内的点配对并进行比较。反过来也有这样的一对二的关系。超出的话,仅 y y y方向就会大于 δ \delta δ。
接下来我们就可以利用鸽巢原理(Pigeonhole Principle)来减少 z z z维度上的检查次数。假设现在要检查点 s s s , 它 属 于 S 1 ,它属于S_1 ,它属于S1区。以它为起始高度取出对应子区间内高度为 δ \delta δ的长方体(如上图),将它分解为16个边长为 δ / 2 \delta/2 δ/2的正方体box,如下图所示:
那么每个box内最大距离 3 δ / 2 < δ \sqrt{3}\delta/2<\delta 3δ/2<δ,即每个box内之多存在一点。那么,对于点 s s s,仅需要检查对应区域中,比它稍大的16个点之间的距离即可。超出16个点后,距离一定会超过这块区域的高度 δ \delta δ,因而可以不用考虑。(这里还需要检查比它稍小的16点,即双向检查,见下)
2.3.3 位置记录
理解算法后,实现时却发现:在分离的 S 1 S_1 S1区 S 2 S_2 S2区中,我怎么知道在另一个区域里,比点 s s s稍高的点是从哪里开始的?所以我们还需要记录两个参考位置。在 S 1 [ i ] S_1[i] S1[i]区中的点 s s s,记录它在 S 2 [ i − 1 ] S_2[i-1] S2[i−1]与 S 2 [ i ] S_2[i] S2[i]中的参考位置,这样在寻找16个点时,就知道起始位置了。
2.3.4 单向检查与双向检查
二维中仅需要单向检查的原因是,遍历点时并未区分 S 1 S_1 S1与 S 2 S_2 S2,每个可能的点对都会被检查到。三维有分 S 1 S_1 S1与 S 2 S_2 S2,如果要类比二维实现单边检查,需要在检查 S 2 S_2 S2区中的点时,反向去找对应的 S 1 S_1 S1。这意味着位置记录需要存储两个方向。
当然,这样实现不是不可以。不过,这里选择更节省空间的做法,只存储 S 1 S_1 S1中的点对应的 S 2 S_2 S2的信息,以及对应时的位置记录。我们仅固定 S 1 S_1 S1中的点去找匹配的、另一区域中的点,且此时向上查找16个,还要向下查找16个。
确保你理解了:不区分 S 1 S_1 S1与 S 2 S_2 S2 → \rightarrow →向上检查即可,区分并仅检查 S 1 S_1 S1中的点 → \rightarrow →需要向上向下双向检查。否则会使整个算法得出不正确的结果!(二维同理)
特例:区分并仅检查 S 1 S_1 S1中的点,单向检查 s ∗ s^* s∗时,将会漏掉红色的点对。不区分时则不会有任何问题。
2.3.5 总结下具体做法
先确定 y y y的最小值用于分区。从小到大遍历有序的 P z P_z Pz,如果点距 x ∗ x^* x∗大于 δ \delta δ,它不属于 S S S;否则:
- 如果这个点 s s s在 S 1 S_1 S1区,加入 S 1 S_1 S1数组。算出它属于 S 1 S_1 S1的第 i i i个子段,那么它就应该对应 S 2 [ i − 1 ] S_2[i-1] S2[i−1]与 S 2 [ i ] S_2[i] S2[i]。现在 l e n ( S 2 [ i − 1 ] ) len(S_2[i-1]) len(S2[i−1])与 l e n ( S 2 [ i ] ) len(S_2[i]) len(S2[i])即是需要记录的参考位置。因为目前这两个区段中的点均低于 s s s。
- 如果这个点 s s s在 S 2 S_2 S2区,将它加入对应的 S 2 [ i ] S_2[i] S2[i]。
此后,遍历 S 1 S_1 S1,并根据 S 2 S_2 S2的分区与参考位置向上、向下检查16个点即可。构建与检查均只需要 O ( n ) O(n) O(n)。
3 算法分析
Divide步骤的拆分和维护是 O ( n ) O(n) O(n)的,Merge步骤也是 O ( n ) O(n) O(n)的,易知消耗的递推公式为:
f n = { 2 f n / 2 + O ( n ) n > 3 O ( 1 ) n ≤ 3 f_n=\left\{ \begin{aligned} &2f_{n/2} + O(n) & n>3\\ &O(1) & n\le 3 \end{aligned} \right. fn={2fn/2+O(n)O(1)n>3n≤3
同二维,整个算法也是 O ( n l o g n ) O(nlogn) O(nlogn)。
4 伪代码
这里贴上伪码,重在理解。源码请移步 文末链接 下一节。
5 Java实现
(算了,我发现手机上我自己都不会去点连接看源码,还是放这里了好了,长就长点)
_(:з」∠)_
Point3.java
package util;
import java.util.Objects;
public class Point3 {
public int idx;
public long x, y, z;
public Point3(long x, long y, long z) {
this.x = x;
this.y = y;
this.z = z;
}
/**
* Check whether the first point is smaller in lexicographical order
*/
public static boolean smaller(Point3 p1, Point3 p2) {
if (p1.x != p2.x) {
return p1.x < p2.x;
} else if (p1.y != p2.y) {
return p1.y < p2.y;
}
return p1.z < p2.z;
}
/**
* Get distance of two points
*/
public static double getDis(Point3 p1, Point3 p2) {
long tmp1 = p1.x - p2.x;
long tmp2 = p1.y - p2.y;
long tmp3 = p1.z - p2.z;
return Math.sqrt(tmp1 * tmp1 + tmp2 * tmp2 + tmp3 * tmp3);
}
@Override
public String toString() {
return "Point{" +
"idx=" + idx +
", x=" + x +
", y=" + y +
", z=" + z +
'}';
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Point3 point3 = (Point3) o;
return x == point3.x &&
y == point3.y &&
z == point3.z;
}
@Override
public int hashCode() {
return Objects.hash(x, y, z);
}
}
Dim3.java
package solve;
import util.Point3;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import static util.Point3.getDis;
public class Dim3 {
private Point3[] px, pz;
/**
* Solve the problem
* @param points the point set
* @return the point pair with smallest distance
*/
public Point3[] solve(Point3[] points) {
int n = points.length;
/* before searching */
px = new Point3[n];
pz = new Point3[n];
System.arraycopy(points, 0, px, 0, n);
System.arraycopy(points, 0, pz, 0, n);
Arrays.sort(px, Comparator.comparingLong(o -> o.x));
Arrays.sort(pz, Comparator.comparingLong(o -> o.z));
// record the point index in px
for (int i = 0; i < n; i++) {
px[i].idx = i;
}
/* Search the ans */
Point3[] res = find(0, n - 1, pz);
/* output the result in lexicographical order */
if (Point3.smaller(res[0], res[1])) {
return new Point3[]{res[0], res[1]};
} else {
return new Point3[]{res[1], res[0]};
}
}
/**
* Find the pair with smallest distance
* @param x1 the left bound of px to find
* @param x2 the right bound of px to find (include)
* @param pz the arr px[x1:x2] ranked by z
* @return the pair expressed in point index
*/
private Point3[] find(int x1, int x2, Point3[] pz){
switch (x2 - x1 + 1) {
case 2:
return new Point3[]{px[x1], px[x2]};
case 3:
double dis12 = getDis(px[x1], px[x1 + 1]);
double dis23 = getDis(px[x1 + 1], px[x2]);
double dis13 = getDis(px[x1], px[x2]);
if (dis12 < dis23) {
if (dis12 < dis13) {
return new Point3[]{px[x1], px[x1 + 1]};
} else {
return new Point3[]{px[x1], px[x2]};
}
} else {
if (dis23 < dis13) {
return new Point3[]{px[x1 + 1], px[x2]};
} else {
return new Point3[]{px[x1], px[x2]};
}
}
}
/* Generate Qx, Rx, Qz, Rz */
int mi = (x1 + x2) / 2;
int idx1 = 0;
int idx2 = 0;
Point3[] qz = new Point3[mi - x1 + 1];
Point3[] rz = new Point3[x2 - mi];
for (Point3 p : pz) {
if (p.idx <= mi) {
qz[idx1++] = p;
} else {
rz[idx2++] = p;
}
}
/* Search recursively */
Point3[] left = find(x1, mi, qz);
Point3[] right = find(mi + 1, x2, rz);
double dis1 = getDis(left[0], left[1]);
double dis3 = getDis(right[0], right[1]);
double delta = Math.min(dis1, dis3);
/* Find minimum distance in crossing-area pair */
long x = px[mi].x;
// Generate S1 and S2
ArrayList<Point3> s1 = new ArrayList<>(); // record all points in s1
ArrayList<Long> segRef = new ArrayList<>(); // record which two seg a point in S1 should refer to
ArrayList<Integer> sl1 = new ArrayList<>(); // location record in s2[j]
ArrayList<Integer> sl2 = new ArrayList<>(); // location record in s2[j+1]
HashMap<Long, ArrayList<Point3>> s2 = new HashMap<>(); // record all s2[]
construct(pz, x, mi, delta, s1, segRef, sl1, sl2, s2);
Point3[] pairMin = new Point3[2];
double disMin = delta;
for (int i = 0; i < s1.size(); i++) {
Point3 s = s1.get(i);
long segIdx = segRef.get(i);
int loc1 = sl1.get(i);
int loc2 = sl2.get(i);
ArrayList<Point3> seg;
// find in seg1
seg = s2.get(segIdx);
if (!seg.isEmpty()){
// look from bottom to top (both sides)
for (int j = Math.max(0, loc1 - 16); j < seg.size() && j <= loc1 + 16; j++) {
Point3 st = seg.get(j);
double tmp = getDis(s, st);
if (tmp < disMin) {
disMin = tmp;
pairMin[0] = s;
pairMin[1] = st;
}
}
}
// find in seg2
seg = s2.get(segIdx + 1);
if (!seg.isEmpty()) {
// look from bottom to top (both sides)
for (int j = Math.max(0, loc2 - 16); j < seg.size() && j <= loc2 + 16; j++) {
Point3 st = seg.get(j);
double tmp = getDis(s, st);
if (tmp < disMin) {
disMin = tmp;
pairMin[0] = s;
pairMin[1] = st;
}
}
}
}
/* Compare and return */
if (pairMin[0] != null) {
return pairMin;
} else if (dis1 < dis3) {
return left;
} else {
return right;
}
}
/**
* Construct S1 and S2
* @param pz the sorted point arr
* @param x the cutting line x = x*
* @param idx the index of px[mi]
* @param delta min(minimum dis in Q, minimum dis in R)
* @param s1 record all points in S1
* @param segRef record which two seg a point in S1 should refer to
* @param sl1 location record in s2[j]
* @param sl2 location record in s2[j+1]
* @param s2 record all s2[] and the points in them
*/
private static void construct(Point3[] pz, long x, int idx, double delta,
ArrayList<Point3> s1,
ArrayList<Long> segRef,
ArrayList<Integer> sl1,
ArrayList<Integer> sl2,
HashMap<Long, ArrayList<Point3>> s2) {
// get the start of y to partition
long yMin = Long.MAX_VALUE;
for (Point3 p : pz) {
if (yMin > p.y) {
yMin = p.y;
}
}
for (Point3 p : pz) {
if (Math.abs(x - p.x) > delta) {
continue;
}
long no = (long)((p.y - yMin) / delta); // in [no*delta:(no+1)*delta] part
if (p.idx <= idx) { // in qx (S1)
s1.add(p);
// build seg ref. save the smaller index of S2 segment
if (no == 0) {
no = -1;
} else {
no = (no - 1) / 2;
}
segRef.add(no);
// build location
ArrayList<Point3> arr;
arr = s2.computeIfAbsent(no, k -> new ArrayList<>());
sl1.add(Math.max(arr.size()-1, 0));
arr = s2.computeIfAbsent(no+1, k -> new ArrayList<>());
sl2.add(Math.max(arr.size()-1, 0));
} else { // in rx (S2)
// add it to correct S2[j]
no = no / 2;
s2.computeIfAbsent(no, k -> new ArrayList<>()).add(p);
}
}
}
}
写在最后:
- 该问题系列的所有实现代码、暴力代码以及对拍/测试代码(Java):戳这里
- 本人初来乍到,这是第一次在CSDN分享博文,如有谬误欢迎指出。