回溯的实质是在问题的解空间进行深度优先搜索。DFS是个图的算法,但是回溯算法中的图在哪里呢?我们把解空间中的一个解状态当成一个节点,由于解空间非常庞大,所以这个图也就大到无法想象了。
举个例子吧,比如全排列问题,对于n个元素进行全排列,一共有n!种可能,比如n=9时,一共有9! = 362880种排列。初始化,我们什么都没有,定义如下状态
- #define PT_SIZE 9
- class PTState
- {
- int m_solution[PT_SIZE];
- int m_arranged;
- };
其中m_arranged表示已经在m_solution中排列的数字,刚开始时啥都没做,当然是m_arranged为0啦,而m_solution这个数组是用来放要排列的数字的,比如
123456789对应于m_solution[i] = i + 1
那么接下来要做啥呢?当然是开始排第一个数了,也就是要填m_solution[0]啦,这时你有多少种选择呢?嗯,当然是1, 2, 3, 4, 5 ,6, 7,8,9种,这还要说吗?那么先尝试谁呢?
从小到小依次填好了。
下面这个图是n=3时的解空间搜索图,其中白色的节点是初始节点,一个数都没填,第一行全是问号,第二行有个标记,表示还没有排列一个数。当第一位选1的时,节点变成了紫色,此时m_arranged = 1, 还剩两个数没有填。这时填第二个数时,发现只能填2或者3,当填上2后,又到了下面一个蓝色的节点,该节点的状态中m_arranged = 2,再扩展下去,第三位只能填3,于是到了最下面的黄色节点。
此时,再想扩展黄色节点时,发现m_arranged已经是3了,这时表明这是一个最终的节点,无法再扩展了,于是得到了一个排列。那么其他排列如何得到呢?回到黄色节点上面的蓝色节点,我们发现第三位除了3再也不能填其他数了,没法继续扩展,没办法,再回到上面的紫色节点,这时发现,紫色节点中的第二位还能填3,于是可以顺利扩展到下面的第二个蓝色节点。这个过程便是回溯和扩展节点的主要过程。
在编程实现时,由于我们无法存储所有的状态节点,只能保持一个当前状态,然后通过扩展到下一个节点,来修改当前节点的状态,但是当扩展完一个节点时,必须得恢复到原来节点的状态,才能进行下一个扩展。这个逻辑写成代码便是如下形式:
- Solve (State & state)
- for action in A
- if state.CanMove(action)
- state.Forward(action);
- Solve(state)
- state.Recover(action);
在我们这个全排列的例子中,具体代码如下
- void Solve(PTState& state)
- {
- if (state.IsFinal())
- {
- gCount ++;
- return;
- }
- for (int i = 1; i <= PT_SIZE; i++)
- {
- if (state.CanPlace(i))
- {
- state.Place(i);
- Solve(state);
- state.Remove(i);
- }
- }
- }
其中的state便是上面图中的节点,其中记录了当前已经放置的数的个数,以及具体数的排列。其中,当扩展到黄色节点,也即发现再也无法扩展时,直接返回。这种节点称为终节点,终结点的判断非常简单,只要查一下m_arranged是不是为PT_SIZE(9)就可以了。当到达终节点时,通常是找到了一个解,这里是找到了有一个新的排列,我们只记录一下个数。
- bool IsFinal()
- {
- return m_arranged == PT_SIZE;
- }
调用时,先初始化状态,然后调用Solve(state)即可,其中state的初始化也非常简单,将m_arranged设为0即可。
然后我们具体实现PTState的几个函数,CanPlace, Place,和Remove, 先来看CanPlace(int i),此时当前状态下已排了m_arranged个数,那么当前要排的是第几个数呢?嗯,是第m_arranged个数(从0开始计数),也就是我们要填m_solution[m_arranged], 但是到底m_solution[m_arranged]能不能填 i 呢?根据排列的定义,前面填过的数不能再填了?那我们怎么知道哪些数填过了呢?
哪些数填过了,这个信息其实也可以放在PTState中,作为状态的一部分,对于Solve函数来说完全不可见,在全排列这个例子中,我们可以用一个数组m_used来记录哪些数是否被排过,当i已经被排列过时,m_used[i] = 1,否则为0.
- class PTState
- {
- int m_solution[PT_SIZE];
- int m_arranged;
- int m_used[PT_SIZE + 1];
- };
初始状态时,m_used为全零。 当查询某个数是否可以填时,CanPlace的实现就相当容易了
- bool CanPlace(int i)
- {
- if (m_used[i] == 0)
- {
- return true;
- }
- return false;
- }
但是我们还得维护状态中的这些信息,这个可以在Place和Remove中悄悄地完成,这两个函数都比较简单:
- void Place(int i)
- {
- m_used[i] = 1;
- m_solution[m_arranged++] = i;
- }
- void Remove(int i)
- {
- m_arranged --;
- m_used[i] = 0;
- }
其中,Place修改了当前state的状态,等于是在搜索图中扩展到了下一个状态,而Remove正好相反,是刚扩展的节点完全搜索了之后,回到上一步的状态,这时需要根据扩展时施行的动作,进行反动作,在这里就是把排的数再扔掉。
完整模板和全排列的代码如下
- int gCount = 0;
- void Solve(PTState& state)
- {
- if (state.IsFinal())
- {
- //state.PrintSolution();
- gCount ++;
- return;
- }
- for (int i = 1; i <= PT_SIZE; i++)
- {
- if (state.CanPlace(i))
- {
- state.Place(i);
- Solve(state);
- state.Remove(i);
- }
- }
- }
- #include "stdafx.h"
- #include "..\Utility\GFClock.h"
- // 1-9
- const int PT_SIZE = 10;
- class PTState
- {
- public:
- PTState()
- {
- m_arranged = 0;
- memset(m_used, 0, sizeof(int) * (PT_SIZE + 1));
- }
- bool IsFinal()
- {
- return m_arranged == PT_SIZE;
- }
- void PrintSolution()
- {
- for (int i = 0; i < PT_SIZE ; i++)
- {
- cout << m_solution[i];
- }
- cout << endl;
- }
- bool CanPlace(int i)
- {
- if (m_used[i] == 0)
- {
- return true;
- }
- return false;
- }
- void Place(int i)
- {
- m_used[i] = 1;
- m_solution[m_arranged++] = i;
- }
- void Remove(int i)
- {
- m_arranged --;
- m_used[i] = 0;
- }
- int m_solution[PT_SIZE];
- int m_arranged;
- int m_used[PT_SIZE + 1];
- };
- int gCount = 0;
- void Solve(PTState& state)
- {
- if (state.IsFinal())
- {
- //state.PrintSolution();
- gCount ++;
- return;
- }
- for (int i = 1; i <= PT_SIZE; i++)
- {
- if (state.CanPlace(i))
- {
- state.Place(i);
- Solve(state);
- state.Remove(i);
- }
- }
- }
- int _tmain(int argc, _TCHAR* argv[])
- {
- GFClock clock;
- PTState state;
- Solve(state);
- cout << gCount << endl;
- cout << clock.Elapsed() << endl;
- }
数独(sudoku)想来大家都不会陌生,下面是一个号称非常难的数独,我们看看用回溯算法解决它需要多少时间。
和全排列一样,使用回溯时首先要设计一个状态类,对于数独而言,这个状态就是这个9×9的格子盘,另外,对于每个格子,我们也抽象出来一个Grid类,具体做啥用,下面会提到。
- class Grid
- {
- public:
- Grid(){}
- int val; //当为0时,grid为空格,非0时,val为已填的数字
- int nRemainCount; // 还有几个数可以填
- //记录该grid不能再填的数字
- map<int, int> valMap;
- bool Conflict(int val)
- void IncCount(int _val)
- void DecCount(int _val);
- };
- class SUDOKUState
- {
- public:
- SUDOKUState(int a[TEMPLATE_SIZE][TEMPLATE_SIZE]);
- bool CanPlace(int val);
- void RemoveNumber(int val)
- void PlaceNumber(int val);
- bool IsFinal();
- Grid m_grids[TEMPLATE_SIZE][TEMPLATE_SIZE];
- std::stack<pair<int,int> > posTrace; //记录放置的位置记录
- int m_curX;//当前要放置数字的空格位置
- int m_curY;
- int m_nRemained; //还有多少个要放置
- bool IsDead();
- void DecideNextPlace()};
回溯模板还是和全排列差不多
- void Solve(SUDOKUState& state)
- {
- if (bSolved)
- {
- return;
- }
- //cout << gCount++ << endl;
- if (state.IsFinal())
- {
- state.PrintBoard();
- bSolved = true;
- return;
- }
- if (state.IsDead())
- {
- return;
- }
- for (int i = 1; i < 10; i++)
- {
- if (!state.CanPlace(i))
- {
- continue;
- }
- state.PlaceNumber(i);
- Solve(state);
- state.RemoveNumber(i);
- }
- }
我们先来看一下数独的状态如何扩展。
初始状态时,已经填了17个格子,那么还有m_nRemained = 81 - 17 = 64个格子没填,m_nRemained这个变量为0时说明状态节点已经是终节点,也即找到了一个数独的解,这里我们不需要把所有解都输出来,所以到找到一个解时,可以设定一个全局参数bSolved为true, 其他节点再扩展时直接返回。
扩展节点时,我们犯愁了,数独未填的格子中我们究竟选哪个填呢?嗯,最简单的做法是随机选一个空格,然后看看这个空格可以填哪些数,比如第二行第七列的空格就只能填4或者9,只能扩展两个节点,而第一行第一列的空格可以填3,4,5,6,8,9六个数。
于是问题就来了,如果我们随机选择填的空格,倘若该空格的候选数字比较多,那么待扩展的节点也会比较多,搜索空间会大很多。这种情况下能不能找到解呢?答案是可以的,但是也许要跑好几天,我一开始试了一下随机选择要填的空格,结果递归调用了1000多万次都没见半点要结束的样子。
这时需要引入一个启发式的方法,即下一步要选择哪个空格填数,按照数独玩家的经验,当然是填候选填数最少的那个空格,比如第二行第7列那个,只有2个备选数字4和7, 状态中的函数DecideNextPlace即是这一贪心方法的实现:
- void DecideNextPlace()
- {
- if (m_nRemained == 0)
- {
- return;
- }
- int minv = 10000;
- for (int r = 0; r < TEMPLATE_SIZE; r++)
- {
- for (int c = 0; c < TEMPLATE_SIZE; c++)
- {
- if (m_grids[r][c].val == 0 && m_grids[r][c].nRemainCount < minv)
- {
- minv = m_grids[r][c].nRemainCount;
- m_curX = c;
- m_curY = r;
- }
- }
- }
- //有一个空格子已经没有数可以选择了
- if (minv == 0)
- {
- m_curX = -1;
- m_curY = -1;
- }
- }
这个方法比较简单,遍历所有空格,看看哪个空格还能用的数字最少,就选哪个空格,m_curX和m_curY分别记录了选中空格的行列,接下来放数字就放在这个格子里头了。
为了快速地获取每个格子的备选数字,我在抽象出来的Grid类中维护了一个map, 用来统计该格子g的同行同列以及同section(3×3的那个子区域)的每个数字的出现次数,显然
如果某个数字的出现次数大于1,那么这个数字就不能在g中出现了,反之可以出现。于是CanPlace可以调用当前要放的空格g的Conflict方法:
- bool Conflict(int val)
- {
- return valMap.find(val) != valMap.end();
- }
- void PlaceNumber(int val)
- {
- m_grids[m_curY][m_curX].val = val;
- m_grids[m_curY][m_curX].IncCount(val);
- for (int r = 0;r < TEMPLATE_SIZE;r++)
- {
- if (r != m_curY)
- {
- m_grids[r][m_curX].IncCount(val);
- }
- }
- for (int c = 0 ; c < TEMPLATE_SIZE; c++)
- {
- if (c != m_curX)
- {
- m_grids[m_curY][c].IncCount(val);
- }
- }
- for (int r = (m_curY / 3) * 3; r < (m_curY / 3) * 3 + 3; r ++)
- {
- for (int c = (m_curX / 3) * 3; c < (m_curX / 3) * 3 + 3; c++)
- {
- if (r == m_curY || c == m_curX)
- {
- continue;
- }
- m_grids[r][c].IncCount(val);
- }
- }
- m_nRemained --;
- posTrace.push(make_pair<int,int>(m_curX, m_curY));
- DecideNextPlace();
- }
这里受影响的格子都调用自身的IncCount方法,表示val这个数的出现又加1了。
- void IncCount(int _val)
- {
- valMap[_val]++;
- nRemainCount = 9 - valMap.size();
- }
- void RemoveNumber(int val)
- {
- assert(!posTrace.empty());
- pair<int, int> prePos = posTrace.top();
- m_curX = prePos.first;
- m_curY = prePos.second;
- m_grids[m_curY][m_curX].val = 0;
- m_grids[m_curY][m_curX].DecCount(val);
- for (int r = 0;r < TEMPLATE_SIZE;r++)
- {
- if (r != m_curY)
- {
- m_grids[r][m_curX].DecCount(val);
- }
- }
- for (int c = 0 ; c < TEMPLATE_SIZE; c++)
- {
- if (c != m_curX)
- {
- m_grids[m_curY][c].DecCount(val);
- }
- }
- for (int r = (m_curY / 3) * 3; r < (m_curY / 3) * 3 + 3; r ++)
- {
- for (int c = (m_curX / 3) * 3; c < (m_curX / 3) * 3 + 3; c++)
- {
- if (r == m_curY || c == m_curX)
- {
- continue;
- }
- m_grids[r][c].DecCount(val);
- }
- }
- posTrace.pop();
- m_nRemained ++;
- }
最后需要说明的是,SUDOKUState中有个IsDead方法
- bool IsDead()
- {
- return m_curY == -1 && m_curY == -1;
- }
这个程序非常快,一共调用了10374次Solve, vs2005 release下只花了52ms
完整代码如下:
- //#include "stdafx.h"
- #include <map>
- #include <stack>
- #include <cassert>
- //#include "..\Utility\GFClock.h"
- using namespace std;
- const int TEMPLATE_SIZE = 9;
- const int SUB_SIZE = 3;
- bool bSolved = false;
- int gCount = 0;
- class Grid
- {
- public:
- Grid(){}
- int val;
- int nRemainCount; // 还有几个数可以填
- //记录该grid不能再填的数字
- map<int, int> valMap;
- bool Conflict(int val)
- {
- return valMap.find(val) != valMap.end();
- }
- //自身或其他地方填了数字影响了当前格子可以选择的数字集合
- void IncCount(int _val)
- {
- valMap[_val]++;
- nRemainCount = 9 - valMap.size();
- }
- void DecCount(int _val)
- {
- valMap[_val]--;
- if (valMap[_val] == 0)
- {
- valMap.erase(_val);
- }
- nRemainCount = 9 - valMap.size();
- }
- };
- int board[TEMPLATE_SIZE][TEMPLATE_SIZE] = {
- {0, 0, 0, 0, 0, 0, 0, 1, 2},
- {0, 0, 0, 0, 3, 5, 0, 0, 0},
- {0, 0, 0, 6, 0, 0, 0, 7, 0},
- {7, 0, 0, 0, 0, 0, 3, 0, 0},
- {0, 0, 0, 4, 0, 0, 8, 0, 0},
- {1, 0, 0, 0, 0, 0, 0, 0, 0},
- {0, 0, 0, 1, 2, 0, 0, 0, 0},
- {0, 8, 0, 0, 0, 0, 0, 4, 0},
- {0, 5, 0, 0, 0, 0, 6, 0, 0}
- };
- class SUDOKUState
- {
- public:
- bool CanPlace(int val)
- {
- return !m_grids[m_curY][m_curX].Conflict(val);
- }
- bool IsFinal()
- {
- return m_nRemained == 0;
- }
- bool IsDead()
- {
- return m_curY == -1 && m_curY == -1;
- }
- //将第y行,第x列的数字挪去
- void RemoveNumber(int val)
- {
- assert(!posTrace.empty());
- pair<int, int> prePos = posTrace.top();
- m_curX = prePos.first;
- m_curY = prePos.second;
- m_grids[m_curY][m_curX].val = 0;
- m_grids[m_curY][m_curX].DecCount(val);
- for (int r = 0;r < TEMPLATE_SIZE;r++)
- {
- if (r != m_curY)
- {
- m_grids[r][m_curX].DecCount(val);
- }
- }
- for (int c = 0 ; c < TEMPLATE_SIZE; c++)
- {
- if (c != m_curX)
- {
- m_grids[m_curY][c].DecCount(val);
- }
- }
- for (int r = (m_curY / 3) * 3; r < (m_curY / 3) * 3 + 3; r ++)
- {
- for (int c = (m_curX / 3) * 3; c < (m_curX / 3) * 3 + 3; c++)
- {
- if (r == m_curY || c == m_curX)
- {
- continue;
- }
- m_grids[r][c].DecCount(val);
- }
- }
- posTrace.pop();
- m_nRemained ++;
- }
- void PlaceNumber(int val)
- {
- m_grids[m_curY][m_curX].val = val;
- m_grids[m_curY][m_curX].IncCount(val);
- for (int r = 0;r < TEMPLATE_SIZE;r++)
- {
- if (r != m_curY)
- {
- m_grids[r][m_curX].IncCount(val);
- }
- }
- for (int c = 0 ; c < TEMPLATE_SIZE; c++)
- {
- if (c != m_curX)
- {
- m_grids[m_curY][c].IncCount(val);
- }
- }
- for (int r = (m_curY / 3) * 3; r < (m_curY / 3) * 3 + 3; r ++)
- {
- for (int c = (m_curX / 3) * 3; c < (m_curX / 3) * 3 + 3; c++)
- {
- if (r == m_curY || c == m_curX)
- {
- continue;
- }
- m_grids[r][c].IncCount(val);
- }
- }
- m_nRemained --;
- posTrace.push(make_pair<int,int>(m_curX, m_curY));
- DecideNextPlace();
- }
- SUDOKUState(int a[TEMPLATE_SIZE][TEMPLATE_SIZE])
- {
- m_nRemained = TEMPLATE_SIZE * TEMPLATE_SIZE;
- for (int i = 0; i <TEMPLATE_SIZE;i++)
- {
- for (int j = 0; j< TEMPLATE_SIZE;j++)
- {
- m_grids[i][j].nRemainCount = TEMPLATE_SIZE;
- m_grids[i][j].val = a[i][j];
- if (a[i][j] != 0)
- {
- //在第i行第j列放了数字a[i][j]
- m_curX = j;
- m_curY = i;
- PlaceNumber(a[i][j]);
- }
- }
- }
- DecideNextPlace();
- }
- //计算下一步放数字的格子,贪心
- void DecideNextPlace()
- {
- if (m_nRemained == 0)
- {
- return;
- }
- int minv = 10000;
- for (int r = 0; r < TEMPLATE_SIZE; r++)
- {
- for (int c = 0; c < TEMPLATE_SIZE; c++)
- {
- if (m_grids[r][c].val == 0 && m_grids[r][c].nRemainCount < minv)
- {
- minv = m_grids[r][c].nRemainCount;
- m_curX = c;
- m_curY = r;
- }
- }
- }
- //有一个空格子已经没有数可以选择了
- if (minv == 0)
- {
- m_curX = -1;
- m_curY = -1;
- }
- }
- void PrintBoard()
- {
- for (int i = 0; i <TEMPLATE_SIZE;i++)
- {
- for (int j = 0; j< TEMPLATE_SIZE;j++)
- {
- cout << m_grids[i][j].val << " ";
- }
- cout << endl;
- }
- }
- Grid m_grids[TEMPLATE_SIZE][TEMPLATE_SIZE];
- std::stack<pair<int,int> > posTrace; //记录放置的位置记录
- int m_nRemained; //还有多少个要放置
- //当前需要放置的位置
- int m_curX;
- int m_curY;
- };
- void Solve(SUDOKUState& state)
- {
- if (bSolved)
- {
- return;
- }
- //cout << gCount++ << endl;
- gCount++;
- if (state.IsFinal())
- {
- state.PrintBoard();
- bSolved = true;
- return;
- }
- if (state.IsDead())
- {
- return;
- }
- for (int i = 1; i < 10; i++)
- {
- if (!state.CanPlace(i))
- {
- continue;
- }
- state.PlaceNumber(i);
- Solve(state);
- state.RemoveNumber(i);
- }
- }
- int _tmain(int argc, _TCHAR* argv[])
- {
- SUDOKUState state(board);
- state.PrintBoard();
- cout << endl;
- //GFClock gfClock;
- Solve(state);
- //cout << gfClock.Elapsed() << " ms" << endl;
- cout << gCount << " times called" << endl;
- return 0;
- }