最邻近点对问题(Closest-Pair Problem):三维的分治解法详解

这是该系列的第三篇。一维、二维的最邻近点对问题(Closest-Pair Problem):戳这里
PS:建议先快速浏览低维的分治解法,因为各维度的解决思路具有高度的关联性。如果还没有弄清二维解法仍要执意往下看,最好先给头发买好保险。


1 问题描述

对于三维点 p i = ( x i 1 , x i 2 , x i 3 ) p_i = (x_{i1}, x_{i2}, x_{i3}) pi=(xi1,xi2,xi3),任意两点的欧几里得距离为 ∑ k = 1 3 ( x i k − x j k ) 2 \sqrt{\sum_{k=1}^3(x_{ik} - x_{jk})^2} k=13(xikxjk)2

暴力算法与二维完全一致,枚举所有可能的点对,计算距离并选择最小距离的点对返回。它的时间复杂度是 O ( n 2 ) O(n^2) O(n2)

同样,借助分治思想可以把复杂度降到 O ( n l o g n ) O(nlogn) O(nlogn)

2 算法描述

2.1 开始前

首先对于一个乱序的点集 P P P,我们需要分别对 x x x z z z进行排序,得到 P x P_x Px P z P_z Pz。前者用于divide,后者用于merge。不对 y y y进行排序的原因是,我们将通过切片来减少这个维度上的重复工作。

2.2 Divide步骤

这一步与二维情况完全一致,只不过按 y y y的情况变为了按 z z z如理解二维的Divide,可跳过。

第一步:拆分

对于任意的点集 P P P,其中点的数量 ∣ P ∣ = n \vert P\vert = n P=n,类似于二维情况,记 ⌊ n / 2 ⌋ \lfloor n/2\rfloor n/2 m i mi mi,可以在 P x P_x Px中第 m i mi mi个点的坐标(记为 x ∗ x^* x)处,将其分为 Q Q Q区与 R R R区。其中坐标小于等于 x ∗ x^* x的点放入 Q Q Q区,大于的点放入 R R R区。

准确来说,分区的标准是点在 P x P_x Px中的位置,小于等于 m i mi mi划入 Q Q Q区,反之 R R R区,不过按 x ∗ x^* x来理解会比较方便。

第二步: 维护有序数组

对于 Q Q Q区和 R R R区,我们想要维护4个有序数组:按 x x x z z z排好序的 Q x Q_x Qx Q z Q_z Qz R x R_x Rx R z R_z Rz。其中,按 x x x排序的数组, Q x Q_x Qx R x R_x Rx,按索引就可以分开了。

Q z Q_z Qz R z R_z Rz,需要新建数组,顺次遍历 P z P_z Pz并判断从属的区域来构建,花费 O ( n ) O(n) O(n)时间,且它们自然就是有序的。注意,如果你想通过 Q x Q_x Qx R x R_x Rx重排序得到,会影响整个算法的复杂度。

现在原问题就变成了两个完全相同的子问题。最终只有小于3个点时,可以直接比较每个点对得出答案。

2.3 Merge步骤

思路与二维是完全一致的。

假设Q区找到了局部最邻近点 { q 1 , q 2 } \{q_1, q_2\} {q1,q2},距离为 δ 1 \delta_1 δ1,同理R区找到了 { r 1 , r 2 } \{r_1, r_2\} {r1,r2},距离为 δ 2 \delta_2 δ2。并且在所有被分割线割离的点对中,即所有点对 { p i , p j } \{p_{i}, p_{j}\} {pi,pj} p i ∈ Q ∧ p j ∈ R p_i \in Q \wedge p_j \in R piQpjR,有一最小距离 δ 3 \delta_3 δ3。现在对于原问题的解,有完全相同于一维的3种情况:

  • 最小距离的点对在被分割线割离的点对集合中,即 δ 3 \delta_3 δ3是最小,应返回对应被割离的点对。
  • Q Q Q区中找到的点对距离最小,应返回 { q 1 , q 2 } \{q_1, q_2\} {q1,q2}
  • R R R区中找到的点对距离最小,应返回 { r 1 , r 2 } \{r_1, r_2\} {r1,r2}

算法的关键点仍然是找出 δ 3 \delta_3 δ3温馨提示:实现下面的关键算法时,请备好一杯茶或者是一包烟。

2.3.1 构建两个 S S S

δ = m i n ( δ 1 , δ 2 ) \delta = min(\delta_1, \delta_2) δ=min(δ1,δ2),把 x ∗ x^* x左右宽为 δ \delta δ的区域记为 S S S区。其中,左半部分是 S 1 S_1 S1区,有半部分是 S 2 S_2 S2区。我们只需要检查 S S S内左右半区的点的组合即可。那么,是否可以利用 P z P_z Pz像二维那样只用检查一部分点对呢?

或许你已经猜到了,多增加的一个维度让二维下的这种算法完全失效。考虑 y y y方向这个维度,仍有可能需要检查所有点对。我们必须对y再次进行分割。

2.3.2 对 y y y进一步分区

如下图所示,我们统计 y y y方向上覆盖的范围,并错位分区(大的红色分区,看图应该就很好理解了吧),然后编号。很容易发现,处于 S 1 [ i ] S_1[i] S1[i]内的一个点,只需要与 S 2 [ i − 1 ] S_2[i-1] S2[i1] S 2 [ i ] S_2[i] S2[i]两个区域内的点配对并进行比较。反过来也有这样的一对二的关系。超出的话,仅 y y y方向就会大于 δ \delta δ

接下来我们就可以利用鸽巢原理(Pigeonhole Principle)来减少 z z z维度上的检查次数。假设现在要检查点 s s s , 它 属 于 S 1 ,它属于S_1 S1区。以它为起始高度取出对应子区间内高度为 δ \delta δ的长方体(如上图),将它分解为16个边长为 δ / 2 \delta/2 δ/2的正方体box,如下图所示:

那么每个box内最大距离 3 δ / 2 < δ \sqrt{3}\delta/2<\delta 3 δ/2<δ,即每个box内之多存在一点。那么,对于点 s s s,仅需要检查对应区域中,比它稍大的16个点之间的距离即可。超出16个点后,距离一定会超过这块区域的高度 δ \delta δ,因而可以不用考虑。(这里还需要检查比它稍小的16点,即双向检查,见下)

2.3.3 位置记录

理解算法后,实现时却发现:在分离的 S 1 S_1 S1 S 2 S_2 S2区中,我怎么知道在另一个区域里,比点 s s s稍高的点是从哪里开始的?所以我们还需要记录两个参考位置。在 S 1 [ i ] S_1[i] S1[i]区中的点 s s s,记录它在 S 2 [ i − 1 ] S_2[i-1] S2[i1] S 2 [ i ] S_2[i] S2[i]中的参考位置,这样在寻找16个点时,就知道起始位置了。

2.3.4 单向检查与双向检查

二维中仅需要单向检查的原因是,遍历点时并未区分 S 1 S_1 S1 S 2 S_2 S2,每个可能的点对都会被检查到。三维有分 S 1 S_1 S1 S 2 S_2 S2,如果要类比二维实现单边检查,需要在检查 S 2 S_2 S2区中的点时,反向去找对应的 S 1 S_1 S1这意味着位置记录需要存储两个方向

当然,这样实现不是不可以。不过,这里选择更节省空间的做法,只存储 S 1 S_1 S1中的点对应的 S 2 S_2 S2的信息,以及对应时的位置记录。我们仅固定 S 1 S_1 S1中的点去找匹配的、另一区域中的点,且此时向上查找16个,还要向下查找16个。

确保你理解了:不区分 S 1 S_1 S1 S 2 S_2 S2 → \rightarrow 向上检查即可,区分并仅检查 S 1 S_1 S1中的点 → \rightarrow 需要向上向下双向检查。否则会使整个算法得出不正确的结果!(二维同理)

特例:区分并仅检查 S 1 S_1 S1中的点,单向检查 s ∗ s^* s时,将会漏掉红色的点对。不区分时则不会有任何问题。

2.3.5 总结下具体做法

先确定 y y y的最小值用于分区。从小到大遍历有序的 P z P_z Pz,如果点距 x ∗ x^* x大于 δ \delta δ,它不属于 S S S;否则:

  1. 如果这个点 s s s S 1 S_1 S1区,加入 S 1 S_1 S1数组。算出它属于 S 1 S_1 S1的第 i i i个子段,那么它就应该对应 S 2 [ i − 1 ] S_2[i-1] S2[i1] S 2 [ i ] S_2[i] S2[i]。现在 l e n ( S 2 [ i − 1 ] ) len(S_2[i-1]) len(S2[i1]) l e n ( S 2 [ i ] ) len(S_2[i]) len(S2[i])即是需要记录的参考位置。因为目前这两个区段中的点均低于 s s s
  2. 如果这个点 s s s S 2 S_2 S2区,将它加入对应的 S 2 [ i ] S_2[i] S2[i]

此后,遍历 S 1 S_1 S1,并根据 S 2 S_2 S2的分区与参考位置向上、向下检查16个点即可。构建与检查均只需要 O ( n ) O(n) O(n)

3 算法分析

Divide步骤的拆分和维护是 O ( n ) O(n) O(n)的,Merge步骤也是 O ( n ) O(n) O(n)的,易知消耗的递推公式为:

f n = { 2 f n / 2 + O ( n ) n > 3 O ( 1 ) n ≤ 3 f_n=\left\{ \begin{aligned} &2f_{n/2} + O(n) & n>3\\ &O(1) & n\le 3 \end{aligned} \right. fn={2fn/2+O(n)O(1)n>3n3

同二维,整个算法也是 O ( n l o g n ) O(nlogn) O(nlogn)

4 伪代码

这里贴上伪码,重在理解。源码请移步 文末链接 下一节。

5 Java实现

(算了,我发现手机上我自己都不会去点连接看源码,还是放这里了好了,长就长点)
_(:з」∠)_

Point3.java

package util;

import java.util.Objects;

public class Point3 {
	public int idx;
	public long x, y, z;

	public Point3(long x, long y, long z) {
		this.x = x;
		this.y = y;
		this.z = z;
	}

	/**
	 * Check whether the first point is smaller in lexicographical order
	 */
	public static boolean smaller(Point3 p1, Point3 p2) {
		if (p1.x != p2.x) {
			return p1.x < p2.x;
		} else if (p1.y != p2.y) {
			return p1.y < p2.y;
		}
		return p1.z < p2.z;
	}

	/**
	 * Get distance of two points
	 */
	public static double getDis(Point3 p1, Point3 p2) {
		long tmp1 = p1.x - p2.x;
		long tmp2 = p1.y - p2.y;
		long tmp3 = p1.z - p2.z;
		return Math.sqrt(tmp1 * tmp1 + tmp2 * tmp2 + tmp3 * tmp3);
	}

	@Override
	public String toString() {
		return "Point{" +
				"idx=" + idx +
				", x=" + x +
				", y=" + y +
				", z=" + z +
				'}';
	}

	@Override
	public boolean equals(Object o) {
		if (this == o) return true;
		if (o == null || getClass() != o.getClass()) return false;
		Point3 point3 = (Point3) o;
		return x == point3.x &&
				y == point3.y &&
				z == point3.z;
	}

	@Override
	public int hashCode() {
		return Objects.hash(x, y, z);
	}
}

Dim3.java

package solve;

import util.Point3;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;

import static util.Point3.getDis;

public class Dim3 {
	private Point3[] px, pz;

	/**
	 * Solve the problem
	 * @param points the point set
	 * @return the point pair with smallest distance
	 */
	public Point3[] solve(Point3[] points) {
		int n = points.length;

		/* before searching */

		px = new Point3[n];
		pz = new Point3[n];
		System.arraycopy(points, 0, px, 0, n);
		System.arraycopy(points, 0, pz, 0, n);
		Arrays.sort(px, Comparator.comparingLong(o -> o.x));
		Arrays.sort(pz, Comparator.comparingLong(o -> o.z));

		// record the point index in px
		for (int i = 0; i < n; i++) {
			px[i].idx = i;
		}

		/* Search the ans */

		Point3[] res = find(0, n - 1, pz);

		/* output the result in lexicographical order */

		if (Point3.smaller(res[0], res[1])) {
			return new Point3[]{res[0], res[1]};
		} else {
			return new Point3[]{res[1], res[0]};
		}
	}

	/**
	 * Find the pair with smallest distance
	 * @param x1 the left bound of px to find
	 * @param x2 the right bound of px to find (include)
	 * @param pz the arr px[x1:x2] ranked by z
	 * @return the pair expressed in point index
	 */
	private Point3[] find(int x1, int x2, Point3[] pz){
		switch (x2 - x1 + 1) {
			case 2:
				return new Point3[]{px[x1], px[x2]};
			case 3:
				double dis12 = getDis(px[x1], px[x1 + 1]);
				double dis23 = getDis(px[x1 + 1], px[x2]);
				double dis13 = getDis(px[x1], px[x2]);
				if (dis12 < dis23) {
					if (dis12 < dis13) {
						return new Point3[]{px[x1], px[x1 + 1]};
					} else {
						return new Point3[]{px[x1], px[x2]};
					}
				} else {
					if (dis23 < dis13) {
						return new Point3[]{px[x1 + 1], px[x2]};
					} else {
						return new Point3[]{px[x1], px[x2]};
					}
				}
		}

		/* Generate Qx, Rx, Qz, Rz */

		int mi = (x1 + x2) / 2;
		int idx1 = 0;
		int idx2 = 0;
		Point3[] qz = new Point3[mi - x1 + 1];
		Point3[] rz = new Point3[x2 - mi];

		for (Point3 p : pz) {
			if (p.idx <= mi) {
				qz[idx1++] = p;
			} else {
				rz[idx2++] = p;
			}
		}

		/* Search recursively */

		Point3[] left = find(x1, mi, qz);
		Point3[] right = find(mi + 1, x2, rz);

		double dis1 = getDis(left[0], left[1]);
		double dis3 = getDis(right[0], right[1]);
		double delta = Math.min(dis1, dis3);

		/* Find minimum distance in crossing-area pair */

		long x = px[mi].x;

		// Generate S1 and S2
		ArrayList<Point3> s1 = new ArrayList<>();   // record all points in s1
		ArrayList<Long> segRef = new ArrayList<>(); // record which two seg a point in S1 should refer to
		ArrayList<Integer> sl1 = new ArrayList<>(); // location record in s2[j]
		ArrayList<Integer> sl2 = new ArrayList<>(); // location record in s2[j+1]
		HashMap<Long, ArrayList<Point3>> s2 = new HashMap<>();   // record all s2[]
		construct(pz, x, mi, delta, s1, segRef, sl1, sl2, s2);

		Point3[] pairMin = new Point3[2];
		double disMin = delta;

		for (int i = 0; i < s1.size(); i++) {
			Point3 s = s1.get(i);
			long segIdx = segRef.get(i);
			int loc1 = sl1.get(i);
			int loc2 = sl2.get(i);

			ArrayList<Point3> seg;

			// find in seg1
			seg = s2.get(segIdx);
			if (!seg.isEmpty()){
				// look from bottom to top (both sides)
				for (int j = Math.max(0, loc1 - 16); j < seg.size() && j <= loc1 + 16; j++) {
					Point3 st = seg.get(j);
					double tmp = getDis(s, st);
					if (tmp < disMin) {
						disMin = tmp;
						pairMin[0] = s;
						pairMin[1] = st;
					}
				}
			}

			// find in seg2
			seg = s2.get(segIdx + 1);
			if (!seg.isEmpty()) {
				// look from bottom to top (both sides)
				for (int j = Math.max(0, loc2 - 16); j < seg.size() && j <= loc2 + 16; j++) {
					Point3 st = seg.get(j);
					double tmp = getDis(s, st);
					if (tmp < disMin) {
						disMin = tmp;
						pairMin[0] = s;
						pairMin[1] = st;
					}
				}
			}
		}

		/* Compare and return */

		if (pairMin[0] != null) {
			return pairMin;
		} else if (dis1 < dis3) {
			return left;
		} else {
			return right;
		}
	}

	/**
	 * Construct S1 and S2
	 * @param pz the sorted point arr
	 * @param x the cutting line x = x*
	 * @param idx the index of px[mi]
	 * @param delta min(minimum dis in Q, minimum dis in R)
	 * @param s1 record all points in S1
	 * @param segRef record which two seg a point in S1 should refer to
	 * @param sl1 location record in s2[j]
	 * @param sl2 location record in s2[j+1]
	 * @param s2 record all s2[] and the points in them
	 */
	private static void construct(Point3[] pz, long x, int idx, double delta,
	                              ArrayList<Point3> s1,
	                              ArrayList<Long> segRef,
	                              ArrayList<Integer> sl1,
	                              ArrayList<Integer> sl2,
	                              HashMap<Long, ArrayList<Point3>> s2) {
		// get the start of y to partition
		long yMin = Long.MAX_VALUE;
		for (Point3 p : pz) {
			if (yMin > p.y) {
				yMin = p.y;
			}
		}

		for (Point3 p : pz) {
			if (Math.abs(x - p.x) > delta) {
				continue;
			}

			long no = (long)((p.y - yMin) / delta); // in [no*delta:(no+1)*delta] part
			if (p.idx <= idx) {     // in qx (S1)
				s1.add(p);

				// build seg ref. save the smaller index of S2 segment
				if (no == 0) {
					no = -1;
				} else {
					no = (no - 1) / 2;
				}
				segRef.add(no);

				// build location
				ArrayList<Point3> arr;
				arr = s2.computeIfAbsent(no, k -> new ArrayList<>());
				sl1.add(Math.max(arr.size()-1, 0));
				arr = s2.computeIfAbsent(no+1, k -> new ArrayList<>());
				sl2.add(Math.max(arr.size()-1, 0));

			} else {                // in rx (S2)
				// add it to correct S2[j]
				no = no / 2;
				s2.computeIfAbsent(no, k -> new ArrayList<>()).add(p);
			}
		}
	}
}



写在最后:

  1. 该问题系列的所有实现代码、暴力代码以及对拍/测试代码(Java):戳这里
  2. 本人初来乍到,这是第一次在CSDN分享博文,如有谬误欢迎指出。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值