问题描述:
如下如所示,a,b分别表示两种状态,每个九宫格中的数字都是0~8共9个数字的任意一种排列,现在要把算出a状态移动到b状态的最佳步骤(移动步数最少)。移动的规则是0方块与上下左右任意一块互换位置。
问题分析:
把每种状态看成一个结点,0块与上下左右块互换位置后形成的状态为之前状态的相邻结点。由此可以见,不过是一个用BFS搜索最短路径的问题。
那么问题来了,由于9和数字的全排列共有9!种情况,面对如此浩大的数字如果搜索的过程不判重,那么运行效率将会极低。问题的关键就是如何判重了。
1. 如何在图的BFS中判断当前九宫格状态是否已经查找过
为了便于表示一种状态,按照如下顺序把数字排成一个int型整数
比如上面的a得到的整数为238016745。为了在计算机中表示该结构,可以使用typedef int State[9];定义一种含有9个元素数组的类型。
方法一:按照整数大小的关系来判断一个数的编码不失为一个不错的选择。换句话说,只需要找到所有比当前整数小的整数的个数,就可以求出当前整数对应的编码。稍加分析我们不难得出求出小于当前整数个数的方法。以238016745为例:
当指针指向下一个位置时
依此类推,不难发现所有小于238016745数的排列总数为
其中val[i]表示当前状态中位置为i对应的数字,m(val[i])表示val[i]的右边的数中小于val[i]的数字个数。
通过算出sum的值,得到当前状态的编号。用一个大小为9!的数组标识某元素是否已经被访问过,假设为int vis[362880];如果vis[n]==0,说明状态n没有被访问过,再vis[n]=1;否则状态n已经被访问过。
方法二:由于方法一中在结点数很大时是开不了这么大的数组的,所以把9!个整数映射为一段整形的区间。我们把由状态对应的整数得到象区间上的一个的函数称为hash函数,该象区间成为hash表。假设所有状态的总数为maxsize,hash表的尺寸为hashsize。哈希表的结构如下所示:
如果每个hash值只对应一种状态,则只需要一个head数组即可表示某状态是否访问过。但实际中是一般的hash函数都有重复,因此,每个head里存放的只映射该hash值的元素的指针,如果用数组保存了九宫格的状态,那么可以用数组下标替代指针,使程序更加简洁。
方法三:使用stl中的set来判断是否有重复的元素
由于该方法中编码不一定得连续,可以直接用该状态整数作为状态的编码。定义set<int> vis;如果vis中没有该整数,那么该状态没有被访问过,再把该整数放入vis中;否则该状态已经被访问过。
2. 关键数据结构
为了表示每个状态,需要State st[362880];
为了便于便于生成邻接结点,需要int dx[]={-1,1,0,0};
int dy[]={0,0,-1,1};
在查找的过程中,可以把队列存储在st的一个片段中,用inthead=1;表示队列的头指针,用intrear=1;表示队列的尾指针。(因为st的第一个元素保存了p中第一元素的父节点下标,所以head应该从1开始)。
为了查找最短路径,用一个距离数组表示dist[362880]表示起始状态到当前状态的最短距离。
对于用hash函数判重的方法,应该有个head[hashsize]表示的个链表的表头指针和next[362880]数组表示某个状态结点的下一状态结点的下标。
为了表示便于回溯打印路径,保存个节点的父节点下标,需要 int p[362880];
3. 打印最短路径
由于BFS到最后是goal状态,因此必须先把路径保存到一个栈中,然后根据p回溯到状态的下标为0,然后从栈顶到栈底打印路径。
解决了以上3个问题,就可以编程实现八数码问题了:
用方法一来判重实现的程序:
#include<iostream>
#include<cstdio>
#include<vector>
using namespacestd;
const intmaxstate = 1000000;
typedef intState[9];
Statest[maxstate];
State goal;
const int dx[] ={ -1, 1, 0, 0 };
const int dy[] ={ 0, 0, -1, 1 };
int head = 1,rear = 2;
int fact[9];
voidinit_tables()
{
fact[0] = 1;
for (int i = 1; i < 9; i++)fact[i] =fact[i - 1] * i;
}
intvis[maxstate];
bool try_to_insert(ints)
{
int sum=0;
for (int i = 0; i < 9; i++){
int cnt = 0;
for (int j = i + 1; j < 9; j++)if(st[s][j] < st[s][i])cnt++;
sum += fact[8 - i] * cnt;
}
if (vis[sum])return false;
else{
vis[sum] = 1;
return true;
}
}
int dist[maxstate];
int p[maxstate];
int bfs()
{
init_tables();
while (head < rear){
State& h = st[head];
if (memcmp(&h, &goal,sizeof(goal)) == 0)return head;
int z;
for (z = 0; z < 9; z++)if(h[z] ==0)break;
int y = z / 3;
int x = z % 3;
for (int i = 0; i < 4; i++){
int newx = x + dx[i];//构造出新状态
int newy = y + dy[i];
if (newx<0 || newx>2 ||newy<0 || newy>2)continue;
int newz = newy * 3 + newx;
State& r = st[rear];
memcpy(&r,&h,sizeof(h));
r[newz] = h[z];
r[z] = h[newz];
if (try_to_insert(rear)){
p[rear] = head;
dist[rear] = dist[head] +1;
rear++;
}
}
head++;
}//while
return 0;
}
vector<int>ans;
void print_ans()
{
int p_id = head;
for (;;){
if (p_id == 0)break;
ans.push_back(p_id);
p_id = p[p_id];
}
for (int i = ans.size() - 1; i >= 0;i--){
for (int j = 0; j < 3; j++){
for (int k = 0; k < 3; k++)
printf("%d",st[ans[i]][j*3+k]);
printf("\n");
}
printf("\n");
}
}
int main()
{
for (int i = 0; i < 9;i++)scanf("%d", &st[1][i]);
for (int j = 0; j < 9;j++)scanf("%d", &goal[j]);
if (bfs()){
print_ans();
printf("%d\n", dist[head]);
}
else printf("-1\n");
return 0;
}
程序输入为:
2 6 4 1 3 7 0 58
8 1 5 7 3 6 4 02
程序输入为:
2 6 4
1 3 7
0 5 8
2 6 4
1 3 7
5 0 8
2 6 4
1 3 7
5 8 0
2 6 4
1 3 0
5 8 7
2 6 4
1 0 3
5 8 7
2 6 4
0 1 3
5 8 7
2 6 4
5 1 3
0 8 7
2 6 4
5 1 3
8 0 7
2 6 4
5 1 3
8 7 0
2 6 4
5 1 0
8 7 3
2 6 0
5 1 4
8 7 3
2 0 6
5 1 4
8 7 3
0 2 6
5 1 4
8 7 3
5 2 6
0 1 4
8 7 3
5 2 6
1 0 4
8 7 3
5 0 6
1 2 4
8 7 3
0 5 6
1 2 4
8 7 3
1 5 6
0 2 4
8 7 3
1 5 6
8 2 4
0 7 3
1 5 6
8 2 4
7 0 3
1 5 6
8 2 4
7 3 0
1 5 6
8 2 0
7 3 4
1 5 6
8 0 2
7 3 4
1 5 6
8 3 2
7 0 4
1 5 6
8 3 2
7 4 0
1 5 6
8 3 0
7 4 2
1 5 0
8 3 6
7 4 2
1 0 5
8 3 6
7 4 2
0 1 5
8 3 6
7 4 2
8 1 5
0 3 6
7 4 2
8 1 5
7 3 6
0 4 2
8 1 5
7 3 6
4 0 2
31
用方法二来判重实现的程序:
#include<iostream>
#include<cstdio>
#include<vector>
using namespacestd;
const intmaxstate = 1000000;
typedef intState[9];
Statest[maxstate];
State goal;
const int dx[] ={ -1, 1, 0, 0 };
const int dy[] ={ 0, 0, -1, 1 };
const inthashsize = 1000003;
int head = 1,rear = 2;
int fact[9];
intmyhead[hashsize];
intmynext[maxstate];
void init_tables()
{
memset(myhead, 0, sizeof(myhead));
}
int myhash(intsrc)
{
int val = 0;
for (int i = 0; i < 9; i++)
val = 10 * val + st[src][i];
return val%hashsize;
}
booltry_to_insert(int s)
{
int hcode = myhash(s);
int u = myhead[hcode];
while (u){
if (memcmp(st[u],st[s],sizeof(st[s])) == 0)
return false;
u = mynext[u];
}
mynext[s] = myhead[hcode];
myhead[hcode] = s;
return true;
}
intdist[maxstate];
int p[maxstate];
int bfs()
{
init_tables();
while (head < rear){
State& h = st[head];
if (memcmp(&h, &goal,sizeof(goal)) == 0)return head;
int z;
for (z = 0; z < 9; z++)if(h[z] ==0)break;
int y = z / 3;
int x = z % 3;
for (int i = 0; i < 4; i++){
int newx = x + dx[i];//构造出新状态
int newy = y + dy[i];
if (newx<0 || newx>2 ||newy<0 || newy>2)continue;
int newz = newy * 3 + newx;
State& r = st[rear];
memcpy(&r,&h,sizeof(h));
r[newz] = h[z];
r[z] = h[newz];
if (try_to_insert(rear)){
p[rear] = head;
dist[rear] = dist[head] +1;
rear++;
}
}
head++;
}//while
return 0;
}
vector<int>ans;
void print_ans()
{
int p_id = head;
for (;;){
if (p_id == 0)break;
ans.push_back(p_id);
p_id = p[p_id];
}
for (int i = ans.size() - 1; i >= 0;i--){
for (int j = 0; j < 3; j++){
for (int k = 0; k < 3; k++)
printf("%d",st[ans[i]][j*3+k]);
printf("\n");
}
printf("\n");
}
}
int main()
{
for (int i = 0; i < 9;i++)scanf("%d", &st[1][i]);
for (int j = 0; j < 9;j++)scanf("%d", &goal[j]);
if (bfs()){
print_ans();
printf("%d\n", dist[head]);
}
else printf("-1\n");
return 0;
}
用方法三判重实现的程序:
#include<iostream>
#include<cstdio>
#include<vector>
#include<set>
using namespacestd;
const intmaxstate = 1000000;
typedef intState[9];
Statest[maxstate];
State goal;
const int dx[] ={ -1, 1, 0, 0 };
const int dy[] ={ 0, 0, -1, 1 };
const inthashsize = 1000003;
int head = 1,rear = 2;
int fact[9];
set<int>vis;
voidinit_tables()
{
vis.clear();
}
booltry_to_insert(int s)
{
int v = 0;
for (int i = 0; i < 9; i++)v = v * 10 +st[s][i];
if (vis.count(v))return false;
vis.insert(v);
return true;
}
intdist[maxstate];
int p[maxstate];
int bfs()
{
init_tables();
while (head < rear){
State& h = st[head];
if (memcmp(&h, &goal,sizeof(goal)) == 0)return head;
int z;
for (z = 0; z < 9; z++)if(h[z] ==0)break;
int y = z / 3;
int x = z % 3;
for (int i = 0; i < 4; i++){
int newx = x + dx[i];//构造出新状态
int newy = y + dy[i];
if (newx<0 || newx>2 ||newy<0 || newy>2)continue;
int newz = newy * 3 + newx;
State&r = st[rear];
memcpy(&r,&h,sizeof(h));
r[newz] = h[z];
r[z] = h[newz];
if (try_to_insert(rear)){
p[rear] = head;
dist[rear] = dist[head] +1;
rear++;
}
}
head++;
}//while
return 0;
}
vector<int>ans;
void print_ans()
{
int p_id = head;
for (;;){
if (p_id == 0)break;
ans.push_back(p_id);
p_id = p[p_id];
}
for (int i = ans.size() - 1; i >= 0;i--){
for (int j = 0; j < 3; j++){
for (int k = 0; k < 3; k++)
printf("%d",st[ans[i]][j*3+k]);
printf("\n");
}
printf("\n");
}
}
int main()
{
for (int i = 0; i < 9;i++)scanf("%d", &st[1][i]);
for (int j = 0; j < 9;j++)scanf("%d", &goal[j]);
if (bfs()){
print_ans();
printf("%d\n", dist[head]);
}
else printf("-1\n");
return 0;
}
从运行的效果来看,该方法运行的时间明显比之前的慢。可以先把STL集合版的程序调试通过后,然后转化为哈希表甚至完美哈希表。