分治算法的基本思想是:
分(divide):递归求解子问题,即:分解+求解,将问题分解为k个方便求解的小问题。
为什么说是递归求解呢,这里可以看作将一个问题分2个子问题,如果2个子问题还是大,再继续分成4个子问题,直到分解到能方便求解的小问题。也就是说分治算法是含有2个以上的递归运算,只有一个递归的例程不能算做分治算法。
治(conquer):从子问题构建原问题的解。
对于分治,最长用到的复杂度分析情况为:
T(N) = aT(N/b) + O(N^k)
当a=b^k时, T(N)=O(N^k*logN)
比如,非常常见的二分法: T(N) = 2T(N/2) + O(N) ,此时a=2, b=2, k=1 即a=b^k,所以 该算法的复杂度为 O(NlogN)
从另一个角度,反向分析问题,如果我们希望得到一个O(NlogN)的算法,那么就需要保证附加工作为O(N),这是一个非常非常关键的利用分治算法解决问题的入手点!!!
而且这个复杂度也是大多数分治算法问题的情况。 当然还有另外两种>, < 情况,就参考书上的详细讲解吧。
下面以一个例子详细介绍如何应用分治算法。
最近点对问题:给定平面上的N个点,找出距离最近的两个点。
对于该问题,算法过程并不算复杂,但要想编程实现,需要克服不少细节问题。
首先应该实现Point类:
Point.h:
#include <iostream>
class Point
{
public:
Point(double x, double y);
double getx() const
{
return m_x;
}
double gety() const
{
return m_y;
}
friend std::ostream& operator<<(std::ostream& os, const Point& p);
private:
double m_x;
double m_y;
};
Point.cpp:
#include "Point.h"
Point::Point(double x, double y) : m_x(x), m_y(y)
{
}
std::ostream& operator<<(std::ostream& os, const Point& p)
{
os<<"P("<<p.getx()<<","<<p.gety()<<")";
return os;
}
这里有几点要注意:重载<<,友元函数的应用,初始化的位置。不多介绍了,代码很简单,上点心看看就好。
接下来,是解决该问题的具体过程:
1. 最简单的解法就是蛮力法:把每两个点的距离求出来,然后找出最小值即可。
虽然,该算法很简单,但编程实现时,从该笨方法起步会有一个很好的过渡,不至于编码难度过陡。
于是我们需要:计算两点距离的函数Distance, 求最近点对函数FindShortPair, 一组点Point p[]用以测试, 打印点函数PrintPoints用于观察。
代码如下: 上面的代码完成了很多该问题算法的外围工作,最重要的是提供了测试环境。注意const的应用,double,int型的定义。
#include "Point.h"
#include <cmath>
double Distance(const Point& s, const Point& t)
{
double squarex = ( s.getx() - t.getx() ) * ( s.getx() - t.getx() );
double squarey = ( s.gety() - t.gety() ) * ( s.gety() - t.gety() );
return sqrt( squarex + squarey );
}
void FindShortPair(const Point * p, int num)
{//can only find one of the shortest path
double distance=Distance(p[0], p[1]);
int start = 0;
int end = 1;
for (int i=0; i<num; i++)
{
for (int j=i+1; j<num; j++)
{
if ( Distance(p[i], p[j]) < distance )
{
distance = Distance(p[i], p[j]);
start = i;
end = j;
}
}
}
std::cout << "The shortest pair is: P" << start+1 << ", P" << end+1 <<""<<std::endl;
std::cout << "the distance is: " << distance<<std::endl;
}
void PrintPoints(const Point * p, int num)
{
for (int i=0; i<num; i++)
{
std::cout << p[i] << " ";
}
std::cout << std::endl;
}
int main(int argc, const char** argv)
{
Point p[] = {Point(2,3), Point(4,3), Point(4,6), Point(5,7), Point(4,3)};
int size = sizeof(p)/sizeof(p[0]);
PrintPoints(p, size);
FindShortPair(p, size);
return 0;
}
2. 蛮力算法复杂度很明显为O(N^2)不理想。如果是O(N*logN)就好多了,下面介绍分治算法在解决该问题的具体应用过程。
假设平面上的点按x排序好了,这样最多增加O(N*logN),这再整个算法来看并没有增加复杂度级别。
排好序后,可以划一条垂线,把点集分成两半:PL和PR。于是最近点对或者在PL中,或者在PR中,或者PL,PR各有一点。
把三种距离情况定义为dL, dR, dC.
其中dL, dR可以递归求解,于是问题就变为计算dC。 根据上面红色字解释,由于我们希望得到O(N*logN)的解,因此必须能够仅仅多花O(N)的附加工作计算dC。
另s=min(dL, dR). 通过观察能得出结论:如果dC<s,即dC对s有所改进,则只需计算dC。如果dC满足这样的条件,则决定dC的两点必然在分割线的s距离之内,称之为带(strip)
否则不可能满足dC<s, 于是缩小了需要考虑的点的范围。
如果是均匀分布的点集,则能证明出在该带中平均只有O(sqrt(N))个点,(注:书上这么写的,我也不会证,先记下这个理论吧)。因此,对这些点运用蛮力法可以在O(N)时间内完成。
于是过程为:
double FindShortPairDC(const Point* p, int num) //DC代表divide and conquer,分治
{
if (num <= 3) //也许您认为,递归到2个点时,才应该返回距离。但如果为3个点,可能会出现PL有2个点,PR有1个点的情况,这时dR会无法计算,所以3个点就要蛮力计算返回。
return EnumShortestPair(p, num);
mid = (num+1)/2;
dL = FindShortPairDC(p, mid);
dR = FindShortPairDC(p+mid, num-mid);
s=min(dL, dR)
for (i=0; i<stripPointNum; i++)
for (j=i+1; j<stripPointNum; j++)
if (dist(pi, pj) < s)
s = dist(pi, pj);
return s;
}
代码实现:注意其中STL中Sort算法的应用方法。
#include "Point.h"
#include <cmath>
#include <algorithm>
double Distance(const Point& s, const Point& t)
{
double squarex = ( s.getx() - t.getx() ) * ( s.getx() - t.getx() );
double squarey = ( s.gety() - t.gety() ) * ( s.gety() - t.gety() );
return sqrt( squarex + squarey );
}
bool ComparePoint(const Point& p1, const Point& p2)
{
return (p1.getx() < p2.getx());
}
double EnumShortestPair(const Point * p, int num)
{//can only find one of the shortest path
double distance=Distance(p[0], p[1]);
int start = 0;
int end = 1;
for (int i=0; i<num; i++)
{
for (int j=i+1; j<num; j++)
{
if ( Distance(p[i], p[j]) < distance )
{
distance = Distance(p[i], p[j]);
start = i;
end = j;
}
}
}
return distance;
}
double FindShortPairDC(const Point * p, int num)
{//use divide and conquer algorithm to find the shortest path
double dL, dR, d, midXVal;
if (num < 2)
{
std::cout << "Need to input more than 2 points!"<< std::endl;
exit(1);
}
if (num < 4)
{
return EnumShortestPair(p, num);
}
int mid = 0;
mid = (num+1)/2;
dL = FindShortPairDC(p, mid);
dR = FindShortPairDC(p+mid, (num-mid));
d = dL < dR ? dL : dR;
midXVal = p[mid].getx();
int stripStart = 0;
int stripEnd = num-1;
for (int i=0; i<num-1; i++)
{
if ( (p[i].getx() < midXVal-d) && (p[i+1].getx() >= midXVal-d) )
stripStart = i+1;
if ( (p[i].getx() <= midXVal+d) && (p[i+1].getx() > midXVal+d) )
stripEnd = i;
}
int start = 0;
int end = 1;
for (int i=stripStart; i<stripEnd; i++)
{
for (int j=i+1; j<stripEnd; j++)
{
if ( Distance(p[i], p[j]) < d )
{
d = Distance(p[i], p[j]);
start = i;
end = j;
}
}
}
if (start!=0 || end!=0)
std::cout << "The shortest pair is: P" << start+1 << ", P" << end+1 <<""<<std::endl;
std::cout << "the distance is: " << d <<std::endl;
return d;
}
void PrintPoints(const Point * p, int num)
{
for (int i=0; i<num; i++)
{
std::cout << p[i] << " ";
}
std::cout << std::endl;
}
int main(int argc, const char** argv)
{
Point p[] = {Point(2,3), Point(4,3), Point(4,6), Point(5,7), Point(4,3)};
int size = sizeof(p)/sizeof(p[0]);
PrintPoints(p, size);
std::sort(p, p+size, ComparePoint);
PrintPoints(p, size);
FindShortPairDC(p, size);
return 0;
}
这里需要解释一下,对于那条垂线的选取,代码并没有按照x坐标取中值,而是取点集的中间位置点 mid 表示PL点集的个数(包括垂线上点),(num-mid)表示PR点集的个数(可能包括垂线上点)。midXVal 为垂线对应的x值。stripStart -- stripEnd为在带中的点集范围,即p[stripStart]到p[stripEnd]
3. 2中的解法最坏情况复杂度仍会上升至O(N*logN), 为了得到O(N*logN)解法,我们仍然需要进行优化。通过进一步观察,我们发现,在带中的点,若进行按y坐标排序后,如果两个点y坐标相差s,则一定不是最短点对,所以只需求y相差不大于s的点对距离即可。这样得到优化的函数:
double FindShortPairDC(const Point* p, int num)
{
if (num <= 3)
return EnumShortestPair(p, num);
mid = (num+1)/2;
dL = FindShortPairDC(p, mid);
dR = FindShortPairDC(p+mid, num-mid);
s=min(dL, dR)
for (i=0; i<stripPointNum; i++)
for (j=i+1; j<stripPointNum; j++)
{
if (pj.y - pi.y > s)
break;
if (dist(pi, pj) < s)
s = dist(pi, pj);
}
return s;
}
代码实现:需要修改两处。
比较函数,增加y信息比较
bool ComparePoint(const Point& p1, const Point& p2)
{
if( fabs(p1.getx() - p2.getx()) < 0.0000000001)
return (p1.gety() < p2.gety());
return (p1.getx() < p2.getx());
}
求最短距离点对,内部循环信息
for (int i=stripStart; i<stripEnd; i++)
{
for (int j=i+1; j<stripEnd; j++)
{
if (p[j].gety()-p[i].gety() > d)
break;
if ( Distance(p[i], p[j]) < d )
{
d = Distance(p[i], p[j]);
start = i;
end = j;
}
}
}
分析下该解法的复杂度:
对于在带内的构成dC两个点pi,pj,这两点一定在一个s*2s的矩形内。否则y距离差便>s
______
|__|__| s
s s
在左右两个s*s方形区域内,最多有4个点,如果再多,则必然有两个点距离<s, 这与s=min(dL, dR)矛盾。
所以在这样一个矩形内,最多存在8个点,亦即对于一个点pi,最多计算7个点与其距离。所以上面的内层循环
for (int j=i+1; j<stripEnd; j++) 最多执行7次,即该内层循环复杂度为O(1), 所以上面双层循环为O(N),可以用O(N)完成带区域内的点集最近点对查找。这样便满足了整体算法复杂度为O(N*logN)