N皇后问题随机搜索(线性冲突检测)
问题:
对于N皇后问题,可以通过回溯法来进行求解,能够找到所有的可行解。随着规模的增大,可以通过随机搜索的方式来快速地获得一个可行的解。
随机搜索的过程中,通过皇后之间的冲突数来衡量结果的好坏。从第一行开始,每个皇后轮询地和后续的皇后进行检查,看是否存在冲突,时间复杂度为O(n^2)。随着时间上升,冲突计算部分是明显的瓶颈。
可以通过其他方法将冲突计算时间复杂度优化为O(n)。网上的方法,现成的代码也不少。但似乎没有太细致地说这件事情,因此有了本文。简单地说一说,希望能够方便读者快速理解。
线性冲突检测
这里通过一个一维数组a来表示皇后的位置,对于a[i],i表示的是第i行,a[i]表示第i行的皇后放置在第a[i]列。
对于a[3]={2,1,3}表示的皇后如下:
通过一维数组定义的情况下,不会出现行冲突和列冲突的情况,只需要的考虑对角线冲突的情况。
表格中白底部分的数值是行和列编号的相加,可以看出来,在一条对角线上的元素行列相加的和都相同。换句话说,对于行列相加值相同的元素,会在一条对角线上。
下面举个例子,a[8]={2,1,6,4,5,6,8,7}(此处下标从1开始)的放置方法如下。行列值为3的皇后两个,行列值为9的皇后有4个,值为15的皇后有2个。同一条对角线上两两冲突,冲突数为(C-1)*C/2,如行列值为3的冲突数为1。
行列值共有2N种情况,扫描a数组,并通过一个长度为2N的数组就能够统计整个a数组的可能行列情况。
上面是反对角线的情况,正对角线也相同。行列值的计算稍作下改动,改为row + n-column +1 。
冲突皇后都统计出来之后,正反对角线数组均扫描一次,就能计算冲突数。
int conflicts(std::vector<int> &vec){
int count=0;
int n=vec.size();
std::vector<int> diag(2*num+1,0); #下标是0开始,但是a[i]值是从1开始的,需要多一个位置
std::vector<int> rDiag(2*num+1,0);
for (int i = 0; i < n; i++) {
diag[n-vec[i] + 1]++;
rDiag[vec[i] + i]++;
}
for (int i = 0; i < diag.size(); i++) {
if (diag[i] > 1) count += diag[i]*(diag[i] - 1)/2;
if (rDiag[i] > 1) count += rDiag[i]*(rDiag[i] - 1)/2;
}
return count;
}
完整代码如下:
#include <vector>
#include <unistd.h>
#include <time.h>
#include <string>
#include <cstdlib>
#include <math.h>
int conflictCount(std::vector<int> &candidate,std::vector<int> diagonalLineVec,std::vector<int> rDiagonalLineVec){
int n=candidate.size();
int count=0;
for (int i = 0; i < n; i++) {
diagonalLineVec[candidate[i] - 1 + n]++;
rDiagonalLineVec[candidate[i] + i]++;
}
for (int i = 0; i < 2 * n; i++) {
if (diagonalLineVec[i] > 1) count += diagonalLineVec[i]*(diagonalLineVec[i] - 1)/2;
if (rDiagonalLineVec[i] > 1) count += rDiagonalLineVec[i]*(rDiagonalLineVec[i] - 1)/2;
}
return count;
}
void queenInit(std::vector<int> &vec){
//配置随机数
unsigned seed;
seed=time(0);
srand(seed);
for(int i=0;i<vec.size();i++)
vec[i]=i+1;
for(int i=0;i<vec.size();i++) //随机交换数组
std::swap(vec[i],vec[rand()%vec.size()]);
}
std::vector<int> queenLocalSearch(unsigned int num){
std::vector<int> tmpResult(num);
std::vector<int> tmp(num);
std::vector<int> diagonalLineVec(num*2+1,0);
std::vector<int> rDiagonalLineVec(num*2+1,0);
queenInit(tmpResult);
int lowestConflictCount=conflictCount(tmpResult,diagonalLineVec,rDiagonalLineVec);
next:
while(lowestConflictCount!=0)
{
for(int i=0;i<num-1;i++)
for(int j=i;j<num;j++)
{
tmp=tmpResult;
std::swap(tmp[i],tmp[j]);
int count=conflictCount(tmp,diagonalLineVec,rDiagonalLineVec);
if(count<lowestConflictCount)
{
lowestConflictCount=count;
tmpResult=tmp;
goto next;
}
}
queenInit(tmpResult);
lowestConflictCount=conflictCount(tmpResult,diagonalLineVec,rDiagonalLineVec);
}
return tmpResult;
}
int main(){
// std::vector<int> testNum{100,150,200,500,1000,1500,3000,5000};
std::vector<int> testNum{50,100,200,500,1000,1500,2000,3000};
for(int i=0;i<testNum.size();i++)
{
clock_t begin=clock();
std::vector<int> vec=queenLocalSearch(testNum[i]);
clock_t end=clock();
printf("Num:%d\n",testNum[i]);
printf("[");
for(int j=0;j<testNum[i]-1;j++)
printf("%d,",vec[j]);
#unix系统的clock()返回的单位是微妙,若是unix系统环境运行,消耗时间计算还需做一下转化
printf("%d]\nElasped Time:%d ms\n\n",vec[testNum[i]-1],(end-begin));
}
return 0;
}