搜索+剪枝(位运算,优化搜索顺序)
搜索不用说,相信你一眼就可以看到是搜索算法,问题是这道题目纯搜索明显是要时间爆炸的,所以我们得剪枝.
优化搜索顺序:很明显,我们肯定是从当前能填合法数字最少的位置开始填数字
位运算:很明显这里面check判定很多,我们必须优化这个check,所以我们可以对于,每一行,每一列, 每一个九宫格,都利用一个九位二进制数保存,当前还有哪些数字可以填写.
lowbit:我们这道题目当前得需要用lowbit运算取出当前可以能填的数字.
166. 数独 - AcWing题库 POJ3074
row[N],col[N],cell[N][N] 都用9位的二进制数表示(初始状态下全为1,代表1-9都可以填),判断在坐标x,y上填某个数是否合法,只需要 row[x]&col[y]&cell[x/3][y/3] ,例如结果是000010001,表示1,5可以填。
优化搜索顺序:每次dfs,从所有可以填的位置选出一个1的个数最少的(即能填的合法数字最少),这样的分支数量最少
预处理ones数组后,可以快速返回一个二进制数中有多少个1
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
#define endl '\n'
char str[100];
const int N=9,M=1<<N;
int row[N],col[N],cell[N][N];
int mp[M],ones[M];
//ones表示0-2^9里每个数有多少个1,mp[x]代表log2(x),比如map[100] = 2
void init()
{
//初始状态下,9位二进制位都为1,代表1-9都可以填
for(int i=0;i<N;i++)row[i]=col[i]=(1<<N)-1;
for(int i=0;i<3;i++)
{
for(int j=0;j<3;j++)
{
cell[i][j]=(1<<N)-1;
}
}
}
//is_set = true则在x, y填上t, 否则则把x,y处的数字删掉, t 是0-8
void draw(int x,int y,int t,int is_set)
{
if(is_set)str[x*N+y]=t+'1';//t 是0-8,所以要+'1',而不是+'0'
else str[x*N+y]='.';
int v=1<<t;
if(!is_set)v=-v;
row[x]-=v;
col[y]-=v;
cell[x/3][y/3]-=v;
}
//x行y列可以填哪个数字, 最后得到2^i + 2^j..+..,这些i, j就是可以填的数字,最后通过map[2^i]来得到这个数字
int get(int x,int y)
{
return row[x]&col[y]&cell[x/3][y/3];
}
bool dfs(int cnt)
{
if(!cnt)return true;//填完所有数字,则返回
int minv=10,x,y; //最多可以填多少个数字
for(int i=0;i<N;i++)
{
for(int j=0;j<N;j++)
{
if(str[i*N+j]=='.')
{
int state=get(i,j);//可以填的数字状态,是1则表示可以填,如000010001,表示1,5可以填,
if(ones[state]<minv)
{
minv=ones[state];//选一个1的个数最少的,这样的分支数量最少
x=i;y=j;
}
}
}
}
int state=get(x,y);
for(;state;state-=state&-state) //依次做lowbit操作,选择每个分支
{
int t=mp[state&-state];//这个t就是要填充的数字
draw(x,y,t,true);//填这个数字
if(dfs(cnt-1))return true; //这次填充成功,则返回
draw(x,y,t,false);//失败则回溯
}
return false;
}
int main() {
// ios::sync_with_stdio(false);
// cout.tie(0);
for(int i=0;i<N;i++)mp[1<<i]=i; //打表,快速地知道可以哪一个数字
for(int i=0;i<1<<N;i++) //ones记录每个状态有多少个1,用于选择分支少的开始搜索, 其中M = 1 << N
{
for(int j=0;j<N;j++)
{
ones[i]+=i>>j&1;
}
}
while(cin>>str)
{
if(str[0]=='e')break;
init();
int cnt=0;
for(int i=0,k=0;i<N;i++)
{
for(int j=0;j<N;j++,k++)
{
if(str[k]!='.'){
int t=str[k]-'1';
draw(i,j,t,true);
}
else cnt++;
}
}
dfs(cnt);
puts(str);
}
return 0;
}