分治算法。顾名思义,把一个大问题分成多个子问题,通常是规模更小的问题。一直分解到足够小足够直接求解。这样小的规模我们称为基本情况。
运用分治算法的必要条件:
1)问题分解后的子问题和原问题具有求解相似性。通常只是规模有差距,也就是说可以直接递归调用自身求解。
2)问题规模到足够小后可以直接求解,即存在递归尽头/基本情况。这个时间通常是常数的。
3)得到多个子问题的解后可以在某个不太大的时间内求出原大问题的解(最难的一步)。这个不太大的时间通常和我们希望的复杂度有关,我们可以通过主公式来估计它。
因而运用分治算法的三个步骤:
1)分解。2)解决:求解基本情况。3)合并:从多个子问题的解 合并或者说求解出 原问题的解。
这种自顶向下分而治之的求解思路在很多方面可以看得到。来看几个例子。我个人觉得归并排序最好理解,之后写排序算法会仔细介绍,这里讲两个例子,一个课本例子最大子数组问题,一个最近点对问题。
- 算法导论例一,最大子数组问题(有不同版本,课本上为求股票差值最大的两天)。事实上这个问题很容易想到线性时间求解的方案,遍历一遍,最大差值的两天的开始一定是个极小值,遍历时碰到更小的极小值就更新极小值,否则计算差值,并保留最大差值及开始结束的信息,遍历完能得到正确答案。这样讲可能不太直观。直接接上代码,这个问题比较简单没有体现分治的魅力。
#include<stdio.h>
typedef int state;
state getmax(int *A,int n,int *begin,int *end);
int main()
{
int n,i;
scanf("%d",&n);
int A[n];
int be,de;
for(i=0;i<n;i++)
{
scanf("%d",&A[i]);
}
int result = getmax(A,n,&be,&de);
printf("最大差值收益:%d \n第%d天买进 第%d天卖出 \n",result,be,de);
return 0;
}
/*
输入:数组指针,数组个数。
最大数组的初始和结束位置( 第一个为0) 的地址指针
返回值 最大差值
算法:根据每个低点分段,最大组一定在两个低点之间(可证明)
复杂度:O(n)
*/
state getmax(int *A,int n,int *begin,int *end)
{
int i=0;
int b=0; //临时左位置
int be,de;//结果的位置 左右
int min = A[0]; //当前低点
int result = 0; //最大差值
//每个极小值点与下个更小极小值点前的所有点比较
for( i = 1 ; i < n ; i++)
{
if(A[i]>min && A[i]-min>result)
{
result = A[i]-min;
be=b;
de=i;
}
else if(A[i]<min)
{
min = A[i];
b=i;
}
else;
}
*begin=be;
*end = de;
return result;
}
-
而分治算法在O(nlgn)下求解,原理是因为这个问题拆分成一半具有同样的求解特性。这也是分治算法的必要条件。最大子数组问题可以看得出有求解相似性。划分成两个子问题可以直接递归求解。求解的基本情况是小于等于小个的时候。一个则差为无穷大,两个差为差值。 最终要的合并,则是在整个数组中求最大的。最大的可能来自前半个,也可能来自后半个,也可能跨越。跨越的时候不难发现一定是前半个的最小值和后半个的最大值组成的。由此把三个值一比就可得到原问题的解。思路即是如此。时间复杂度分析:T(n)=2T(n/2)+O(n).其中两个子问题求解时间为T(n/2).而每次要返回数组中的最大值或者最小值。求解时间O(n)。
-
再讲个例子。最近点对问题。这个问题比较巧。
问题为:已知平面上分布着点集P中的n个点p1,p2,…pn,点i的坐标记为(xi,yi),1≤i≤n。两点之间的距离取其欧式距离。即 z = ( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 z= \sqrt{(x1-x2)^{2}+(y1-y2)^{2}} z=(x1−x2)2+(y1−y2)2
问题:找出平面上距离最近的两个点及其距离。(点可重复,距离为0)
分析问题:
如果计算所有点对的距离,然后取最小值,问题复杂度在O(n^2),显然这里不是我们想要的,我们可以通过分治法来得到一个更好的答案。
那么这个问题是否满足分治法要求呢,第一步划分是比较好划分的(有些问题需要通过巧妙的划分才能归纳到原问题上,要仔细观察)。这里不妨以x坐标轴为标准,以x中位数为标准,左边为一部分,右边为一部分。这样得到两个点集。可以用递归求解且基本情况还是点数量小于等于两个时候。故分别可以用T(n/2)的时间求解这两个子问题。
接下来就是由两个子问题的解分析得到原问题的解,现在我们有n个点,要使解答比 n^2 更好,这个合并复杂度不能高于n(由主公式可得,最后给出)。同样的,最近点对可以由左半部分和右半部分给出,也可以是跨越两部分。核心难点在于跨越部分的点的最近距离如何求。看起来似乎要分别匹配n/2个点,复杂度n^2.
要求解跨越两边的点对的最近距离,总数是n个点,要充分利用递归出的结果。先抽丝剥茧,我们可以得到左半部分点对和右半部分点对最近距离的最小值,记为m。跨越部分的点对如果距离大于这个最小值肯定就不是答案了。这样可以在中位数两边各划一个m宽的带,若某点的x坐标距离x中位数超过这个值,则肯定不会成为跨越部分的最近点对。这样问题缩小了吗?
感觉小了,但是要是两个n/2个点全在这个带中间呢,这样还不够,还需要进一步抽丝剥茧,拿掉一些不需要考虑的情况。从y坐标入手,如果y坐标相差超过m,同样不用考虑了,那么对于每个点,有几个点可能和它距离在m以内呢?我们要是有个排好序的y坐标就好了。这样以y坐标遍历一下,对每个点来说,最多可能有7个点和它有可能。常数时间*N。很完美。
为什么最多只有7个呢,对每个点来说,只需要在一个x坐标长2m,y坐标宽m的长方形中考虑,其中若从上到下遍历,则这个点在矩形的上边沿,往下走m宽。(以y坐标遍历不再分x)因为对左边(右边)的点来说,最小距离大于等于m,所以左半边最多4个,同理右半边最多4个,所以最多和7个比。
整理一下思路,我们需要往递归中传两个数组,一个以x排好序,一个以y坐标排序,还要便于划分。排序时间nlgn。递归中由排序数组分隔成子数组需要n时间。然后筛选除掉x坐标大于中位数+m,小于中位数-m的点。时间n。从y坐标遍历,每个点常数时间得与它的最小值,时间也是n。总共时间T(n)=2T(n/2) + O(n).
看一下Python代码实现。
import random
import time
from math import sqrt as G
from math import pow as P
def GetNearest(X,Y):
Num = len(X)
if Num == 2:
# 返回元组第一项为距离,第二项为列表,储存最近的点对
return (round(G(P(X[0][0] - X[1][0], 2)+P(X[0][1] - X[1][1], 2)),2),[(X[0],X[1])])
if Num < 2:
return (float("inf"),[])
XL = X[0:Num//2]
XR = X[Num//2:Num]
#筛选出x坐标符合要求的点对,即划分成左右两部分。注意左右的相符,倒序排列的,所以左边的是大的右边是小等的
YL = list(filter(lambda d: d[0] > X[Num//2][0], Y))
YR = list(filter(lambda d: d[0] <= X[Num // 2][0], Y))
LNearest = GetNearest(XL, YL)
RNearest = GetNearest(XR, YR)
#细节注意,左右两部分的点对距离可能相同,这样的话要把左右算得的点对都加到列表里去,不能遗漏
if LNearest[0] < RNearest[0]:
MinLR = LNearest
elif LNearest[0] == RNearest[0]:
MinLR = LNearest
MinLR[1][len(MinLR[1]):] = RNearest[1]
else:
MinLR = RNearest
DY = list(filter(lambda d: X[Num//2][0]-MinLR[0]<=d[0]<=X[Num//2][0]+MinLR[0], Y))
for i in range(len(DY)): #O(m)
for j in range(i+1,len(DY)):#最多只有7个 所以是常量级别
if DY[i][1] - DY[j][1] > MinLR[0]:#7次之内会跳出
break
else:
MinZ = round(G(P(DY[i][0] - DY[j][0], 2)+P(DY[i][1] - DY[j][1], 2)), 2)
if MinLR[0] > MinZ:
MinLR = (MinZ, [(DY[i], DY[j])])
elif MinLR[0] == MinZ:
# 注意这时候没有区分左右,可能这个点对已经在列表中,并且顺序可能不同
if (DY[i], DY[j]) not in MinLR[1] and (DY[j], DY[i]) not in MinLR[1]:
MinLR[1].append((DY[i], DY[j]))
return MinLR
a = 0
b = 10000
n = int(input("请输入需要生成的点对数量近似值(随机生成后会去重):"))
SetDot = set()
for i in range(n):
x =round((b-a) * random.random(), 1)
y =round((b-a) * random.random(), 1)
SetDot.add((x, y))
#测试用例 SetDot = {(8.0, 3.0), (8.1, 3.1),(3.8, 9.8), (3.7, 9.9),(7.7, 0.9), (7.8, 0.8)}
Dot = list(SetDot)
FTimeS = time.time()
XDot = sorted(Dot, key=lambda d: -d[0])#根据x坐标降序排列 底层是归并排序
YDot = sorted(Dot, key=lambda d: -d[1])#根据y坐标降序排列 底层是归并排序
NearestLen = GetNearest(XDot, YDot)
FTimeE = time.time()
print("分治法结果:", NearestLen[0])
print("分治法得最近点对有:",end=" ")
for i in range(len(NearestLen[1])):
print(NearestLen[1][i], end=" ")
print()
print("分治法用时:", FTimeE - FTimeS)
FTimeS = time.time()
Nearest = (float("inf"),[])
for i in range(len(Dot)):
for j in range(i+1,len(Dot)):
N = round(G(P(Dot[i][0] - Dot[j][0], 2)+P(Dot[i][1] - Dot[j][1], 2)), 2)
if N < Nearest[0]:
Nearest = (N, [(Dot[i], Dot[j])])
elif N == Nearest[0]:#加入都是最近距离的点对
Nearest[1].append((Dot[i], Dot[j]))
FTimeE = time.time()
print("暴力法结果:", Nearest[0])
print("暴力法得最近点对有:",end=" ")
for i in range(len(Nearest[1])):
print(Nearest[1][i],end=" ")
print()
print("暴力法用时:", FTimeE - FTimeS)
print("去重后最终生成了{}个点:".format(len(Dot)))
for i in range(len(Dot)):
print(Dot[i])
可以尝试跑一下,对于数据规模稍微大一点的情况,可以看到分治法比暴力法好很多。
有兴趣的可以去搜搜大整数相乘的分治法求解,和strassen方法求矩阵乘法问题相似,用加减法减少乘法次数来达到减小复杂度的目的,可以动手尝试一下。
最后给出主公式,具体证明算法导论上有。
主定理(定理4.1)
令
a
≥
1
和
b
≥
1
是
常
数
,
f
(
n
)
是
一
个
函
数
。
T
(
n
)
是
定
义
在
非
负
整
数
上
的
递
归
式
:
T
(
n
)
=
a
∗
T
(
n
/
b
)
+
f
(
n
)
令a\geq 1和b\geq 1 是常数,f(n)是一个函数。T(n)是定义在非负整数上的递归式: T(n)=a*T(n/b)+f(n)
令a≥1和b≥1是常数,f(n)是一个函数。T(n)是定义在非负整数上的递归式:T(n)=a∗T(n/b)+f(n)
其中n/b解释为向上取整或者向下取整(不影响)。那么T(n)有如下渐进界:
1.若对某个常数
ξ
>
0
\xi > 0
ξ>0,有
f
(
n
)
=
O
(
n
l
o
g
b
a
−
ξ
)
f(n)=O(n^{log_{b}a-\xi })
f(n)=O(nlogba−ξ)。则
T
(
n
)
=
Θ
(
n
l
o
g
b
a
)
T(n)=\Theta (n^{log_{b}a})
T(n)=Θ(nlogba)
2.若
f
(
n
)
=
O
(
n
l
o
g
b
a
)
f(n)=O(n^{log_{b}a})
f(n)=O(nlogba)。则
T
(
n
)
=
Θ
(
n
l
o
g
b
a
∗
l
g
n
)
T(n)=\Theta (n^{log_{b}a}*lgn)
T(n)=Θ(nlogba∗lgn)
3.若对某个常数
ξ
>
0
\xi > 0
ξ>0,有
f
(
n
)
=
Ω
(
n
l
o
g
b
a
+
ξ
)
f(n)=\Omega (n^{log_{b}a+\xi })
f(n)=Ω(nlogba+ξ),且对某个常数c<1和所有足够大的n有
a
∗
f
(
n
/
b
)
≤
c
∗
f
(
n
)
a*f(n/b)\leq c*f(n)
a∗f(n/b)≤c∗f(n),则
T
(
n
)
=
Θ
(
f
(
n
)
)
T(n)=\Theta (f(n))
T(n)=Θ(f(n))。
即情况1为
f
(
n
)
f(n)
f(n)多项式小于
O
(
n
l
o
g
b
a
)
O(n^{log_{b}a})
O(nlogba),第二种为多项式相等,第三种为多项式大于且等式成立。注意这之间是有间隙的,有使用不了的情况。