前段时间人工智能的课介绍到A*算法,于是便去了解了一下,然后试着用这个算法去解决经典的八数码问题,一开始写用了挺久时间的,后来试着把算法的框架抽离出来,编写成一个通用的算法模板,这样子如果以后需要用到A*算法的话就可以利用这个模板进行快速开发了(对于刷OJ的题当然不适合,不过可以适用于平时写一些小游戏之类的东西)。
A*算法的原理就不过多介绍了,网上能找到一大堆,核心就是估价函数 g() 的定义(注意:本文代码里的命名与教科书相反,astar_f() 表示已走的步数,astar_g() 表示启发式估价,astar_h() 是两者之和,即优先级),这个会直接影响搜索的速度。我在代码里使用 C++/Java 的多态性来编写业务无关的算法模板,用一个抽象类来表示搜索树中的状态,A*算法主类直接操纵这个抽象类,然后编写自己业务相关的类去继承这个抽象类并实现其中的所有抽象方法(C++里是纯虚函数),之后调用A*算法主类的 run 函数就能得到一条可行并且是最短的搜索路径,下面具体看代码:(文末附所有代码的 github 地址)
先看 c++ 部分,毕竟一开始就是用 c++ 来写的
首先是表示状态的抽象基类CState,头文件 state.h:
#ifndef __state_h
#define __state_h

#include <cstddef>
#include <vector>

using std::vector;

/**
 * Abstract node of the search tree explored by CAstar.
 *
 * Naming note: the function names are swapped relative to the textbook
 * convention. Here astar_f() is the cost so far, astar_g() is the heuristic
 * estimate and astar_h() is the priority (their sum).
 *
 * A concrete state must implement operator<, getNextState() and astar_g();
 * the other virtual functions have usable defaults.
 */
class CState
{
public:
    CState();

    /// Strict ordering between states. CAstar treats two states as identical
    /// when neither one compares less than the other.
    virtual bool operator < (const CState &) const = 0;

    /// Validation hook; CAstar's constructor calls start.checkSomeFields(end).
    /// The base implementation does nothing.
    virtual void checkSomeFields(const CState &) const;

    /// Returns the successors of this state as heap-allocated objects.
    /// The caller takes ownership of every returned pointer.
    virtual vector<CState*> getNextState() const = 0;

    /// Calls getNextState() and fills in iSteps and pparent on each successor.
    vector<CState*> __getNextState() const;

    /// Cost so far; the base implementation returns iSteps.
    virtual long astar_f() const;

    /// Heuristic estimate; the smaller the value, the higher the priority.
    virtual long astar_g() const = 0;

    /// Priority used by the search; the base implementation returns
    /// astar_f() + astar_g().
    virtual long astar_h() const;

    virtual ~CState();

    int iSteps;             ///< number of moves from the start state
    const CState *pparent;  ///< predecessor; must point at an object that
                            ///< stays alive (never at a local variable)
};

#endif
源文件 state.cpp:
#include "state.h"
#include <algorithm>

using std::for_each;

/// A fresh state is a root: zero steps taken and no predecessor.
CState::CState(): iSteps(0), pparent(nullptr) {}

/// Default validation hook: accepts everything.
void CState::checkSomeFields(const CState &) const {}

/// Asks the concrete class for its successors, then records on each of them
/// that it is one step further than this state and that this state is its
/// predecessor.
vector<CState*> CState::__getNextState() const
{
    vector<CState*> nextState = getNextState();
    for_each(nextState.begin(), nextState.end(), [this](CState *pstate) {
        pstate->iSteps = this->iSteps + 1;
        pstate->pparent = this;
    });
    return nextState;
}

/// Cost so far: the number of moves taken from the start state.
long CState::astar_f() const
{
    return iSteps;
}

/// Search priority: cost so far plus the heuristic estimate.
long CState::astar_h() const
{
    return astar_f() + astar_g();
}

CState::~CState() {}
子类只需实现小于运算符,getNextState(),astar_g() 这三个纯虚函数就可以了,另外几个虚函数可以不重写,直接用父类的即可。
然后是A*算法主类 CAstar,头文件 astar.h:
#ifndef __ASTAR_H#define __ASTAR_H#include"state.h"#include
using std::set;classCAstar
{public:
CAstar(const CState &_start, const CState &_end);static set getStateByStartAndSteps(const CState &start, intsteps);voidrun();~CAstar();const CState &m_rStart, &m_rEnd;boolbCanSolve;intiSteps;
vectorvecSolve;longlRunTime;intiTotalStates;private:setpointerWaitToDelete;
};#endif
源文件 astar.cpp:
#include "astar.h"
#include "timeval.h"
#include "exception.h"
#include <algorithm>
#include <cstddef>
#include <functional>
#include <queue>
#include <set>
#include <vector>

using std::set;
using std::queue;
using std::priority_queue;
using std::swap;
using std::max;
using std::sort;
using std::function;

/// Stores references to both endpoints and lets the start state validate
/// itself against the end state (may throw whatever checkSomeFields throws).
CAstar::CAstar(const CState &_start, const CState &_end):
    m_rStart(_start), m_rEnd(_end), bCanSolve(false), iSteps(0), vecSolve{},
    lRunTime(0), iTotalStates(0), pointerWaitToDelete{}
{
    m_rStart.checkSomeFields(m_rEnd);
}

/// Orders pointers by the objects they point at, so that a set keyed on
/// pointers holds one entry per distinct state rather than per address.
template <typename T>
struct CPointerComp
{
    bool operator () (const T &pl, const T &pr) const
    {
        return *pl < *pr;
    }
};

/// Breadth-first walk from `start` that collects every distinct state whose
/// depth is exactly `steps`.
///
/// All intermediate states are freed before returning. The caller owns the
/// returned states; note that for steps == 0 the only element is &start,
/// which the caller must not delete.
set<const CState*> CAstar::getStateByStartAndSteps(const CState &start, int steps)
{
    set<const CState*> retSet;
    set<const CState*, CPointerComp<const CState*> > inSet;
    inSet.insert(&start);
    queue<const CState*> queState;
    queState.push(&start);

    while (!queState.empty()) {
        const CState* const pCurState = queState.front();
        queState.pop();
        if (pCurState->iSteps > steps) {
            continue;
        }
        if (pCurState->iSteps == steps) {
            retSet.insert(pCurState);
            continue;
        }
        const vector<CState*> nextState = pCurState->__getNextState();
        for (CState *pNext : nextState) {
            if (inSet.find(pNext) == inSet.end()) {
                queState.push(pNext);
                inSet.insert(pNext);
            } else {
                delete pNext;   // state already seen: drop the duplicate
            }
        }
    }

    // Whatever is neither the caller's start object nor part of the result
    // was allocated here and is no longer needed.
    inSet.erase(&start);
    for (const CState *pResult : retSet) {
        inSet.erase(pResult);
    }
    for (const CState *pRest : inSet) {
        delete pRest;
    }
    return retSet;
}

/// Turns std::priority_queue into a min-heap on astar_h().
struct priority_state
{
    bool operator () (const CState* const lhs, const CState* const rhs) const
    {
        return lhs->astar_h() > rhs->astar_h();
    }
};

/// Best-first search from m_rStart towards m_rEnd.
///
/// On success bCanSolve, iSteps and vecSolve (end state first, start state
/// last) are filled in; iTotalStates and lRunTime are always set. The search
/// gives up once more than 60,000,000 distinct states are stored.
/// Intended to be called once per object.
void CAstar::run()
{
    typedef set<const CState*, CPointerComp<const CState*> > StateSet;
    const std::size_t kMaxStates = 6000 * 10000;

    CTimeVal _time;
    StateSet setState;
    setState.insert(&m_rStart);
    priority_queue<const CState*, vector<const CState*>, priority_state> queState;
    queState.push(&m_rStart);

    while (!queState.empty()) {
        const CState *pHeadState = queState.top();
        queState.pop();

        // Equal to the end state when neither compares less than the other.
        if (!(*pHeadState < m_rEnd) && !(m_rEnd < *pHeadState)) {
            bCanSolve = true;
            iSteps = pHeadState->iSteps;
            for (const CState *p = pHeadState; p != nullptr; p = p->pparent) {
                vecSolve.push_back(p);
            }
            break;
        }

        const vector<CState*> nextState = pHeadState->__getNextState();
        for (CState *pNext : nextState) {
            StateSet::iterator state_it = setState.find(pNext);
            if (state_it == setState.end()) {
                queState.push(pNext);
                setState.insert(pNext);
            } else if ((*state_it)->astar_f() > pNext->astar_f()) {
                // A cheaper route to a known state. The old object may still
                // be queued or referenced as a predecessor, so park it for
                // the destructor. Record it before erase() invalidates the
                // iterator.
                if (*state_it != &m_rStart) {
                    pointerWaitToDelete.insert(*state_it);
                }
                setState.erase(state_it);
                setState.insert(pNext);
                queState.push(pNext);
            } else {
                delete pNext;
            }
        }
        if (setState.size() > kMaxStates) {
            break;
        }
    }

    iTotalStates = static_cast<int>(setState.size());
    lRunTime = _time.costTime();

    // Free every stored state that is not on the solution path. Membership is
    // decided by address: matching by state equivalence could skip a
    // replacement object (leaking it) while the superseded object on the
    // path gets freed twice later.
    const set<const CState*> onPath(vecSolve.begin(), vecSolve.end());
    for (const CState *pStored : setState) {
        if (pStored != &m_rStart && onPath.find(pStored) == onPath.end()) {
            delete pStored;
        }
    }
}

/// Frees the solution path and the superseded states, each exactly once.
CAstar::~CAstar()
{
    for (const CState *pStep : vecSolve) {
        const bool borrowed = (pStep == &m_rStart || pStep == &m_rEnd);
        const bool parked = pointerWaitToDelete.find(pStep) != pointerWaitToDelete.end();
        if (!borrowed && !parked) {
            delete pStep;
        }
    }
    for (const CState *pState : pointerWaitToDelete) {
        delete pState;
    }
}
主搜索函数里是以 广度优先搜索 + 优先队列 来实现A*算法的,因为是用多态来实现,用到了指针,所以有些细节可能写得不是很好看,但是经运行测试过没有明显的bug,cpu和内存的使用均在正常的范围内。
以上两个类就是A*算法的主体框架了,但里面用到了自定义的异常类 CException 和计时类 CTimeVal 等一些工具类,具体代码可以在后面的 github 地址里看到。
然后是业务相关的类,这里首先是八数码问题的类 CChess,头文件 chess.h:
#ifndef __CCHESS_H#define __CCHESS_H#include"state.h"#include#include#include
using std::string;usingstd::vector;usingstd::ostream;class CChess: publicCState
{
friend ostream& operator << (ostream &, const CChess &);static intiLimitNum;public:
CChess(const string &state, int row, int col, const string &standard="");virtual bool operator < (const CState &) const;virtual void checkSomeFields(const CState &) const;const string& getStrState() const;void setStrStandard(const string &);virtual vector getNextState() const;//virtual long astar_f() const;
virtual long astar_g() const;//virtual long astar_h() const;
private:void check_row_col() const;void check_value() const;void check_standard() const;
inlineint countNotMatch() const;
inlineint countLocalNotMatch(int, int) const;private:stringstrState;intiRow, iCol;intiZeroIdx;stringstrStandard;intiNotMatch;public:intiMoveFromLast;static const string directs[5];enumDIRECT
{
UP, DOWN, LEFT, RIGHT, UNKOWN
};void output(ostream &out, const string &colSpace=" ", const string &rowSpace="\n") const;
};#endif
chess.h
源文件 chess.cpp:
#include "chess.h"
#include "exception.h"
#include <algorithm>
#include <cstdio>
#include <cstring>

using std::sort;
using std::swap;

int CChess::iLimitNum = 20;
const string CChess::directs[5] = {"up", "down", "left", "right", "unkown"};

/// Validates the board dimensions against the limit and the string length.
void CChess::check_row_col() const
{
    if (iRow <= 0 || iCol <= 0) {
        throw CException(1001, "行或列的值不能小于等于0!");
    }
    // Widen before multiplying so that huge user input cannot overflow int.
    const long long cells = static_cast<long long>(iRow) * iCol;
    if (cells > iLimitNum) {
        char msg[100];
        snprintf(msg, sizeof(msg), "行列数的乘积不能超过%d!", iLimitNum);
        throw CException(1002, msg);
    }
    if (cells != static_cast<long long>(strState.size())) {
        throw CException(1003, "行列数的乘积应该和字符串的长度相等!");
    }
}

/// Requires a blank ('0') and pairwise distinct characters.
void CChess::check_value() const
{
    if (iZeroIdx < 0) {
        throw CException(1004, "字符串值不合法,应该含有'0'!");
    }
    bool seen[256] = {false};
    for (const char c : strState) {
        // char may be signed; index through unsigned char to stay in range.
        const unsigned char idx = static_cast<unsigned char>(c);
        if (seen[idx]) {
            throw CException(1005, "字符串中不能含有相同的字符!");
        }
        seen[idx] = true;
    }
}

/// Requires the target layout to be a permutation of the current layout.
void CChess::check_standard() const
{
    if (strState.size() != strStandard.size()) {
        throw CException(1006, "目标状态的字符长度与当前状态的字符长度不等!");
    }
    bool origin[256] = {false};
    bool standard[256] = {false};
    for (const char c : strState) {
        origin[static_cast<unsigned char>(c)] = true;
    }
    for (const char c : strStandard) {
        standard[static_cast<unsigned char>(c)] = true;
    }
    for (int i = 0; i < 256; ++i) {
        if (origin[i] != standard[i]) {
            throw CException(1007, "目标状态的字符内容与当前状态的字符内容不等!");
        }
    }
}

/// Builds and validates a board; see the checks above for the CException
/// codes that can be thrown. An empty `standard` means "sorted layout".
CChess::CChess(const string &state, int row, int col, const string &standard):
    CState(), strState(state), iRow(row), iCol(col), iZeroIdx(-1),
    strStandard(standard), iNotMatch(0), iMoveFromLast(UNKOWN)
{
    check_row_col();
    const string::size_type zeroPos = strState.find('0');
    if (zeroPos != string::npos) {
        iZeroIdx = static_cast<int>(zeroPos);
    }
    check_value();
    if (strStandard == "") {
        strStandard = strState;
        sort(strStandard.begin(), strStandard.end());
    }
    check_standard();
    iNotMatch = countNotMatch();
}

/// Verifies that `rhs` (expected to be a CChess) has the same dimensions and
/// the same set of characters as this board.
void CChess::checkSomeFields(const CState &rhs) const
{
    const CChess &other = static_cast<const CChess&>(rhs);
    if (iRow != other.iRow) {
        throw CException(2001, "开始字符串和结束字符串的行不相同!");
    }
    if (iCol != other.iCol) {
        throw CException(2002, "开始字符串和结束字符串的列不相同!");
    }
    string tmp_this = strState;
    string tmp_rhs = other.strState;
    sort(tmp_this.begin(), tmp_this.end());
    sort(tmp_rhs.begin(), tmp_rhs.end());
    if (tmp_this != tmp_rhs) {
        throw CException(2003, "开始字符串和结束字符串含有的字符有差别!");
    }
}

/// Orders boards by layout string, then by row count, then by column count.
/// `rhs` is expected to be a CChess.
bool CChess::operator < (const CState &rhs) const
{
    const CChess &other = static_cast<const CChess&>(rhs);
    const int cmp = strcmp(strState.c_str(), other.strState.c_str());
    if (cmp == 0) {
        if (iRow == other.iRow) {
            return iCol < other.iCol;
        }
        return iRow < other.iRow;
    }
    return cmp < 0;
}

const string& CChess::getStrState() const
{
    return strState;
}

/// Replaces the target layout, re-validates it and refreshes the mismatch count.
void CChess::setStrStandard(const string &standard)
{
    strStandard = standard;
    check_standard();
    iNotMatch = countNotMatch();
}

/// Counts the cells whose character differs from the target layout.
int CChess::countNotMatch() const
{
    int notMatch = 0;
    const int cells = iRow * iCol;
    for (int i = 0; i < cells; ++i) {
        if (strState[i] != strStandard[i]) {
            ++notMatch;
        }
    }
    return notMatch;
}

/// Incremental mismatch count. Must be called on a copy whose cells `one`
/// and `two` have just been swapped while iNotMatch still holds the value
/// from before the swap: the pre-swap contribution of both cells is
/// subtracted and the post-swap contribution added.
int CChess::countLocalNotMatch(int one, int two) const
{
    int oldNotMatch = 0;
    if (strState[two] != strStandard[one]) {
        ++oldNotMatch;
    }
    if (strState[one] != strStandard[two]) {
        ++oldNotMatch;
    }
    int nowNotMatch = 0;
    if (strState[one] != strStandard[one]) {
        ++nowNotMatch;
    }
    if (strState[two] != strStandard[two]) {
        ++nowNotMatch;
    }
    return this->iNotMatch - oldNotMatch + nowNotMatch;
}

/// Generates up to four successors by swapping the blank with a neighbour.
/// The recorded direction is the direction in which the *tile* moved.
/// The caller owns the returned objects.
vector<CState*> CChess::getNextState() const
{
    vector<CState*> nextChess;
    const int cells = static_cast<int>(strState.size());

    // Copies this board, swaps the blank with the cell at `target` and
    // appends the result.
    auto addMove = [this, &nextChess](int target, int move) {
        CChess *next = new CChess(*this);
        swap(next->strState[target], next->strState[iZeroIdx]);
        next->iNotMatch = next->countLocalNotMatch(target, iZeroIdx);
        next->iZeroIdx = target;
        next->iMoveFromLast = move;
        nextChess.push_back(next);
    };

    if (iZeroIdx >= iCol) {             // a tile above the blank can move down
        addMove(iZeroIdx - iCol, CChess::DOWN);
    }
    if (iZeroIdx + iCol < cells) {      // a tile below the blank can move up
        addMove(iZeroIdx + iCol, CChess::UP);
    }
    if (iZeroIdx % iCol != 0) {         // a tile left of the blank can move right
        addMove(iZeroIdx - 1, CChess::RIGHT);
    }
    if ((iZeroIdx + 1) % iCol != 0) {   // a tile right of the blank can move left
        addMove(iZeroIdx + 1, CChess::LEFT);
    }
    return nextChess;
}

/// Heuristic: number of cells that differ from the target layout.
long CChess::astar_g() const
{
    return iNotMatch;
}

/// Prints the board row by row; separators are written between cells and
/// between rows only, not after the last one.
void CChess::output(ostream &out, const string &colSpace, const string &rowSpace) const
{
    for (int i = 0; i < iRow; ++i) {
        for (int j = 0; j < iCol; ++j) {
            out << strState[i * iCol + j];
            if (j != iCol - 1) {
                out << colSpace;
            }
        }
        if (i != iRow - 1) {
            out << rowSpace;
        }
    }
}

std::ostream& operator << (std::ostream &out, const CChess &oChess)
{
    oChess.output(out);
    out << "\n";
    return out;
}
chess.cpp
八数码问题当时是花了挺久时间做了很大的优化的,最后是main函数,用于简单的交互功能:
#include "chess.h"
#include "exception.h"
#include "astar.h"
#include "timeval.h"
#include <iostream>
#include <string>
#include <vector>

using namespace std;

/// Interactive driver for the sliding puzzle: repeatedly reads a start board
/// and an end board ("<layout> <rows> <cols>"), runs the search and prints
/// the path. Ends on end-of-input or on an unexpected exception.
int main(int argc, char const *argv[])
{
    string str;
    int r, c;
    while (true) {
        try {
            cout << "please input the start state(string) and row, col, separate with a space:\n";
            if (!(cin >> str >> r >> c)) {
                break;
            }
            CChess start(str, r, c);

            cout << "please input the end state(string) and row, col, separate with a space:\n";
            if (!(cin >> str >> r >> c)) {
                break;
            }
            CChess end(str, r, c);
            // The heuristic of the start board must measure distance to the
            // requested end layout rather than to the default sorted layout.
            start.setStrStandard(str);

            cout << "Your game is:\n" << start << "--->\n" << end << "\n";
            CAstar game(start, end);
            game.run();

            if (game.bCanSolve) {
                cout << "your game can be solve:\n";
                cout << "the total steps is:" << game.iSteps << "\n";
                cout << "and the path is:\n";
                // vecSolve runs from the end state back to the start state,
                // so print it in reverse.
                const int len = static_cast<int>(game.vecSolve.size());
                cout << *static_cast<const CChess*>(game.vecSolve[len - 1]) << "\n";
                for (int i = len - 2; i >= 0; --i) {
                    const CChess *step = static_cast<const CChess*>(game.vecSolve[i]);
                    cout << "|\n";
                    cout << "|" << CChess::directs[step->iMoveFromLast] << "\n";
                    cout << "\\|/\n\n";
                    cout << *step << "\n";
                }
            } else {
                cout << "sorry, your game can't be solve, please input the other state.\n\n";
            }
            cout << "and the max states is:" << game.iTotalStates << "\n";
            cout << "and the runtime is:" << game.lRunTime << "\n";
        } catch (const CException &ex) {
            // Invalid input: report it and ask again.
            cerr << ex.code << ":" << ex.msg << "\n";
        } catch (...) {
            break;
        }
    }
    return 0;
}
main.cpp
还写了个用于生成测试用例的程序:
#include "astar.h"
#include "chess.h"
#include "exception.h"
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>

using namespace std;

/// Prints the command-line synopsis.
void usage(const string &exe_name)
{
    string echo = "Usage: " + exe_name + " start(string) row(positive int) col(positive int) steps(positive int).";
    cout << echo << "\n";
}

/// Parses `str` as a T via a string stream; yields a value-initialized T
/// (0 for arithmetic types) when nothing can be parsed.
template <typename T>
T strTo(const string &str)
{
    stringstream ss;
    ss << str;
    T ret = T();
    ss >> ret;
    return ret;
}

/// Test-case generator: prints one board (or, when a fifth argument is
/// given, every board) that is exactly `steps` moves away from the start.
/// Exit codes: 1 too few arguments, 2 bad row/col, 3 CException, 4 other.
int main(int argc, char const *argv[])
{
    if (argc < 5) {
        usage(argv[0]);
        exit(1);
    }
    const int row = strTo<int>(argv[2]);
    const int col = strTo<int>(argv[3]);
    if (!row || !col) {
        usage(argv[0]);
        exit(2);
    }
    const int steps = strTo<int>(argv[4]);
    const string strStandard(argv[1]);
    const bool outputOne = (argc < 6);

    try {
        CChess start(strStandard, row, col);
        auto setChess = CAstar::getStateByStartAndSteps(start, steps);
        if (setChess.size() == 0) {
            throw CException(3001, "走不了这么多步!");
        }

        const auto print = [](const CState *elem) {
            const CChess *chess = static_cast<const CChess*>(elem);
            chess->output(cout);
            cout << "\t\t" << chess->getStrState() << "\n";
        };
        if (outputOne) {
            print(*setChess.begin());
        } else {
            cout << "setChess.size() =" << setChess.size() << "\n";
            for_each(setChess.begin(), setChess.end(), print);
        }

        // With steps == 0 the set contains &start itself, which lives on the
        // stack and must not be deleted.
        for_each(setChess.begin(), setChess.end(), [&start](const CState *elem) {
            if (elem != &start) {
                delete elem;
            }
        });
    } catch (const CException &ex) {
        cerr << ex.code << ":" << ex.msg << "\n";
        exit(3);
    } catch (...) {
        cerr << "unkown error.\n";
        exit(4);
    }
    return 0;
}
rand_init.cpp
makefile文件:
# Builds the eight-puzzle solver (game) and the test-case generator
# (rand_init). Both link against libtools, which is built in ./tools.
CC = g++
COMOPT = -std=c++11
INCLUDEDIR = -I./tools
LIBDIR = -L./tools
LIBS = -ltools
LINK = $(LIBDIR) $(LIBS)
# OBJS = $(patsubst %.cpp, %.o, $(wildcard *.cpp))
OBJS += chess.o astar.o state.o
OUTPUT += game rand_init

.PHONY: all clean

all: $(OUTPUT)

game: $(OBJS) main.o
	$(MAKE) -C tools
	$(CC) -o $@ $^ $(LINK)

rand_init: $(OBJS) rand_init.o
	$(MAKE) -C tools
	$(CC) -o $@ $^ $(LINK)

# NOTE(review): the original rule was cut off after "-c $"; the flags below
# are the obvious completion from the variables defined above.
%.o: %.cpp
	$(CC) -o $@ -c $< $(COMOPT) $(INCLUDEDIR)

clean:
	$(MAKE) clean -C tools
	rm -f *.o
	rm -f $(OUTPUT)
makefile
然后是传教士过河问题,CState 类和 CAstar 类和上面一样,具体的业务实现类 CPersonState 如下:
#ifndef __PERSON_H#define __PERSON_H#include"state.h"#include
usingstd::ostream;class CPersonState: publicCState
{
friend ostream& operator << (ostream&, const CPersonState&);public:
CPersonState();virtual bool operator < (const CState &) const;virtual vector getNextState() const;virtual long astar_g() const;void init(int, int, int);public:static int iTotalMissionary; //the total number of missionaries
static int iTotalSavage; //the total number of savages
static int iBoatCapacity; //the capacity of the boat
private:int iMissionary; //the number of missionaries in the shore where boat anchors
int iSavage; //the number of savages in the shore where boat anchors
int iBoatPosition; //the position of boat, this shore or opposite shore
public:enumPOSITION
{
THIS_SHORE= 1, OPPOSITE_SHORE
};intiMoveMissionary;intiMoveSavage;
};#endif
person.h
#include "person.h"
#include "exception.h"
#include <algorithm>

using std::min;
using std::max;

// -1 marks "not configured yet"; init() refuses to run until all three are set.
int CPersonState::iTotalMissionary = -1;
int CPersonState::iTotalSavage = -1;
int CPersonState::iBoatCapacity = -1;

/// Creates an empty state; every field gets a defined value so that the
/// object is safe to compare or print even before init() is called.
CPersonState::CPersonState():
    CState(), iMissionary(0), iSavage(0), iBoatPosition(THIS_SHORE),
    iMoveMissionary(0), iMoveSavage(0)
{}

/// Stores the given values and validates them.
/// @param _m missionaries on the shore where the boat anchors
/// @param _s savages on the shore where the boat anchors
/// @param _b boat position, THIS_SHORE (1) or OPPOSITE_SHORE (2)
/// @throws CException codes 101-107 describing the first violated rule
void CPersonState::init(int _m, int _s, int _b)
{
    iMissionary = _m;
    iSavage = _s;
    iBoatPosition = _b;
    iMoveMissionary = iMoveSavage = 0;
    if (iTotalMissionary == -1) {
        throw CException(101, "the total number of missionaries has not been initialized.");
    }
    if (iTotalSavage == -1) {
        throw CException(102, "the total number of savages has not been initialized.");
    }
    if (iBoatCapacity == -1) {
        throw CException(103, "the capacity of the boat has not been initialized.");
    }
    if (iMissionary > iTotalMissionary) {
        throw CException(104, "the number of missionaries on this shore exceeded the total number.");
    }
    if (iSavage > iTotalSavage) {
        throw CException(105, "the number of savages on this shore exceeded the total number.");
    }
    // NOTE(review): this check was truncated in the published listing; the
    // condition follows the safety rule used in getNextState(), the message
    // wording is a reconstruction.
    if (iMissionary && iMissionary < iSavage) {
        throw CException(106, "the number of missionaries on this shore is less than the number of savages.");
    }
    if (iBoatPosition != CPersonState::THIS_SHORE && iBoatPosition != CPersonState::OPPOSITE_SHORE) {
        throw CException(107, "the value of iBoat is invalid, which must be CPersonState::THIS_SHORE "
                              "or CPersonState::OPPOSITE_SHORE, you can use 1 or 2 certainly.");
    }
}

/// Orders states by missionaries, then savages, then boat position.
/// `rhs` is expected to be a CPersonState.
bool CPersonState::operator < (const CState &rhs) const
{
    const CPersonState &other = static_cast<const CPersonState&>(rhs);
    if (iMissionary != other.iMissionary) {
        return iMissionary < other.iMissionary;
    }
    if (iSavage != other.iSavage) {
        return iSavage < other.iSavage;
    }
    return iBoatPosition < other.iBoatPosition;
}

/// Enumerates every boat load (x missionaries, y savages) that leaves both
/// shores safe, and returns the resulting states as seen from the shore the
/// boat arrives at. The caller owns the returned objects.
vector<CState*> CPersonState::getNextState() const
{
    vector<CState*> nextState;
    const int oppo_m = iTotalMissionary - iMissionary;
    const int oppo_s = iTotalSavage - iSavage;
    const int mk = min(iMissionary, iBoatCapacity);
    const int sk = min(iSavage, iBoatCapacity);

    for (int x = 0; x <= mk; ++x) {
        for (int y = 0; y <= sk; ++y) {
            if (!x && !y) continue;     // the boat cannot cross empty
            // Shore being left: missionaries, if any remain, must not be
            // outnumbered. Taking more savages may repair this, so go on.
            if (iMissionary - x != 0 && iMissionary - x < iSavage - y) continue;
            // Boat overloaded, or missionaries outnumbered on board: a larger
            // y only makes it worse, so stop this row.
            if (x + y > iBoatCapacity || (x && y > x)) break;
            // Arrival shore unsafe: likewise only gets worse with larger y.
            if (oppo_m + x != 0 && oppo_m + x < oppo_s + y) break;

            CPersonState *_next = new CPersonState();
            _next->init(oppo_m + x, oppo_s + y, 3 - iBoatPosition);
            _next->iMoveMissionary = x;
            _next->iMoveSavage = y;
            nextState.push_back(_next);
        }
    }
    return nextState;
}

/// Heuristic: the number of people currently standing on THIS_SHORE.
long CPersonState::astar_g() const
{
    if (iBoatPosition == CPersonState::THIS_SHORE) {
        return iMissionary + iSavage;
    }
    return (iTotalMissionary - iMissionary) + (iTotalSavage - iSavage);
}

ostream& operator << (ostream &out, const CPersonState &state)
{
    out << "(" << state.iMissionary << "," << state.iSavage << "," << state.iBoatPosition << ")";
    return out;
}
person.cpp
main 函数:
#include "person.h"
#include "astar.h"
#include "exception.h"
#include <iostream>

using namespace std;

/// Interactive driver for the river-crossing puzzle: repeatedly reads the
/// number of missionaries, savages and the boat capacity, runs the search
/// and prints the crossing plan. Ends on end-of-input.
int main(int argc, char const *argv[])
{
    int m, s, k;
    while (true) {
        cout << "please input the number of missionaries, savages and the capacity of the boat, separate with a space:\n";
        if (!(cin >> m >> s >> k)) {
            break;
        }
        CPersonState::iTotalMissionary = m;
        CPersonState::iTotalSavage = s;
        CPersonState::iBoatCapacity = k;

        CPersonState start, end;
        try {
            start.init(m, s, CPersonState::THIS_SHORE);
            end.init(m, s, CPersonState::OPPOSITE_SHORE);
        } catch (const CException &ex) {
            cerr << ex.code << ":" << ex.msg << "\n";
            // The states are invalid; searching from them is meaningless,
            // so ask for new input instead.
            continue;
        } catch (...) {
            break;
        }

        CAstar game(start, end);
        game.run();
        if (game.bCanSolve) {
            cout << "your game can be solve:\n";
            cout << "the total steps is:" << game.iSteps << "\n";
            cout << "and the path is:\n";
            // vecSolve runs from the end state back to the start state,
            // so print it in reverse.
            const int len = static_cast<int>(game.vecSolve.size());
            cout << *static_cast<const CPersonState*>(game.vecSolve[len - 1]) << "\n";
            for (int i = len - 2; i >= 0; --i) {
                const CPersonState *pstate = static_cast<const CPersonState*>(game.vecSolve[i]);
                cout << "\n |\n";
                cout << "| (" << pstate->iMoveMissionary << "," << pstate->iMoveSavage << ")\n";
                cout << "\\|/\n\n";
                cout << *pstate << "\n";
            }
        } else {
            cout << "sorry, your game can't be solve, please input another state.\n\n";
        }
        cout << "the total steps is:" << game.iSteps << "\n";
        cout << "and the max states is:" << game.iTotalStates << "\n";
        cout << "and the runtime is:" << game.lRunTime << "\n";
    }
    return 0;
}
main.cpp
makefile 文件(和上面的相似,只是编译的目标项稍有不同):
# Builds the river-crossing solver (across_river). It links against
# libtools, which is built in ./tools.
CC = g++
COMOPT = -std=c++11 -g
INCLUDEDIR = -I./tools
LIBDIR = -L./tools
LIBS = -ltools
LINK = $(LIBDIR) $(LIBS)
# OBJS = $(patsubst %.cpp, %.o, $(wildcard *.cpp))
OBJS += person.o astar.o state.o
OUTPUT += across_river

.PHONY: all clean

all: $(OUTPUT)

across_river: $(OBJS) main.o
	$(MAKE) -C tools
	$(CC) -o $@ $^ $(LINK)

# NOTE(review): the original rule was cut off after "-c $"; the flags below
# are the obvious completion from the variables defined above.
%.o: %.cpp
	$(CC) -o $@ -c $< $(COMOPT) $(INCLUDEDIR)

clean:
	$(MAKE) clean -C tools
	rm -f *.o
	rm -f $(OUTPUT)
makefile
之后我用 Java 来重写,除了面向对象的语法有区别以外其它都几乎是一样的:
首先是自定义异常类 MyException:
packagetools;/*** 自定义的异常类,错误码和错误信息的简单封装*/
public class MyException extendsRuntimeException {private static final long serialVersionUID = 1L;public int code; //错误码
public String msg; //错误信息
public MyException(intcode, String msg) {super();this.code =code;this.msg =msg;
}
@OverridepublicString toString() {return "MyException [code=" + code + ", msg=" + msg + "]";
}
}
MyException.java
计时类 TimeValue:
packagetools;public classTimeValue {private longmilliSecond;/*** 初始化时获取当前系统时间(millisecond)*/
publicTimeValue() {super();this.milliSecond =System.currentTimeMillis();
}public TimeValue(longmilliSecond) {super();this.milliSecond =milliSecond;
}/*** 返回耗时,以毫秒为单位*/
public longcostTime() {long nowMilliSecond =System.currentTimeMillis();return nowMilliSecond - this.milliSecond;
}/*** 重置时间为当前时间*/
public voidreset() {this.milliSecond =System.currentTimeMillis();
}
@OverridepublicString toString() {return "TimeValue [milliSecond=" + milliSecond + "]";
}
}
TimeValue.java
抽象类 State:
package main;

import java.util.ArrayList;

import tools.MyException;

/**
 * Abstract node of the search tree explored by Astar.
 *
 * Naming note: astar_f() is the cost so far, astar_g() is the heuristic
 * estimate and astar_h() is the priority (their sum), i.e. the letters are
 * swapped relative to the textbook convention.
 */
public abstract class State {
    /** Number of moves from the start state. */
    public int steps;
    /** Predecessor on the path from the start state; null for the start. */
    public State parent;

    public State() {
        super();
        this.steps = 0;
        this.parent = null;
    }

    // Astar stores states in a HashMap, so every concrete state must define
    // identity explicitly; equal states must produce equal hash codes.
    abstract public int hashCode();

    abstract public boolean equals(Object obj);

    /** Validation hook called by Astar's constructor; does nothing by default. */
    public void checkSomeFields(State rhs) throws MyException {}

    /** Returns the successors of this state. */
    abstract public ArrayList<State> getNextState();

    /** Calls getNextState() and fills in steps and parent on each successor. */
    public ArrayList<State> __getNextState() {
        ArrayList<State> nextState = this.getNextState();
        for (State st : nextState) {
            st.steps = this.steps + 1;
            st.parent = this;
        }
        return nextState;
    }

    /** Cost so far: the number of moves taken. */
    public long astar_f() {
        return this.steps;
    }

    /** Heuristic estimate; smaller values are explored first. */
    abstract public long astar_g();

    /** Search priority: cost so far plus the heuristic estimate. */
    public long astar_h() {
        return this.astar_f() + this.astar_g();
    }
}
A*算法类 Astar:
package main;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.Comparator;
import java.util.HashMap;

import tools.MyException;
import tools.TimeValue;

/**
 * Problem-independent A* driver working on State objects.
 */
public class Astar {
    public State start;
    public State end;
    /** True once the end state was reached. */
    public boolean canSolve;
    /** Length of the path found. */
    public int steps;
    /** Path found, ordered from the end state back to the start state. */
    public ArrayList<State> vecSolve;
    /** Search time in milliseconds. */
    public long runTime;
    /** Number of distinct states kept. */
    public int totalStates;

    /**
     * Stores both endpoints and lets the start state validate itself against
     * the end state.
     */
    public Astar(State start, State end) throws MyException {
        super();
        this.start = start;
        this.end = end;
        this.canSolve = false;
        this.steps = 0;
        this.vecSolve = new ArrayList<State>();
        this.runTime = 0;
        this.totalStates = 0;
        start.checkSomeFields(end);
    }

    /** Not implemented yet (the C++ version has not been ported); returns an empty set. */
    static Set<State> getStateByStartAndSteps(State start, int steps) {
        Set<State> retSet = new HashSet<>();
        return retSet;
    }

    /**
     * Best-first search from start towards end. Fills in canSolve, steps and
     * vecSolve on success; totalStates and runTime are always set. The search
     * gives up once more than 30,000,000 distinct states are stored.
     */
    void run() {
        TimeValue _time = new TimeValue();
        Map<State, State> mapState = new HashMap<>();
        mapState.put(this.start, this.start);

        // Min-heap on astar_h(). Long.compare avoids the overflow a plain
        // subtraction could run into.
        Queue<State> queState = new PriorityQueue<>(new Comparator<State>() {
            @Override
            public int compare(State o1, State o2) {
                return Long.compare(o1.astar_h(), o2.astar_h());
            }
        });
        queState.add(this.start);

        while (!queState.isEmpty()) {
            State headState = queState.poll();
            if (headState.equals(this.end)) {
                this.canSolve = true;
                this.steps = headState.steps;
                for (State st = headState; st != null; st = st.parent) {
                    this.vecSolve.add(st);
                }
                break;
            }

            ArrayList<State> nextState = headState.__getNextState();
            for (State _next : nextState) {
                State known = mapState.get(_next);
                if (known == null) {
                    queState.add(_next);
                    mapState.put(_next, _next);
                } else if (known.astar_f() > _next.astar_f()) {
                    // A cheaper route to a known state. put() alone would keep
                    // the old key object, hence the explicit remove().
                    mapState.remove(_next);
                    mapState.put(_next, _next);
                    queState.add(_next);
                }
            }
            if (mapState.size() > 3000 * 10000) {
                break;
            }
        }
        this.totalStates = mapState.size();
        this.runTime = _time.costTime();
    }
}
用于表示传教士过河状态的具体类 PersonState:
package main;

import java.util.ArrayList;

import tools.MyException;

/**
 * State of the missionaries-and-savages river crossing for the A* template.
 *
 * A state records how many missionaries and savages stand on the shore where
 * the boat is currently anchored, plus which shore that is.
 */
public class PersonState extends State {
    // -1 marks "not configured yet"; init() refuses to run until all three
    // are set. (Without the initializer the fields would start at 0 and the
    // checks in init() could never fire.)
    static public int totalMissionary = -1; // the total number of missionaries
    static public int totalSavage = -1;     // the total number of savages
    static public int boatCapacity = -1;    // the capacity of the boat

    private int missionary;   // the number of missionaries on the shore where the boat anchors
    private int savage;       // the number of savages on the shore where the boat anchors
    private int boatPosition; // the position of the boat, this shore or the opposite shore

    public static final int THIS_SHORE = 1;
    public static final int OPPOSITE_SHORE = 2;

    /** Missionaries carried by the move that led to this state. */
    public int moveMissionary;
    /** Savages carried by the move that led to this state. */
    public int moveSavage;

    public PersonState() {
        super();
    }

    /**
     * Stores the given values and validates them.
     *
     * @param _m missionaries on the shore where the boat anchors
     * @param _s savages on the shore where the boat anchors
     * @param _b boat position, THIS_SHORE or OPPOSITE_SHORE
     * @throws MyException codes 101-107 describing the first violated rule
     */
    public void init(int _m, int _s, int _b) throws MyException {
        this.missionary = _m;
        this.savage = _s;
        this.boatPosition = _b;
        this.moveMissionary = moveSavage = 0;
        if (totalMissionary == -1) {
            throw new MyException(101, "the total number of missionaries has not been initialized.");
        }
        if (totalSavage == -1) {
            throw new MyException(102, "the total number of savages has not been initialized.");
        }
        if (boatCapacity == -1) {
            throw new MyException(103, "the capacity of the boat has not been initialized.");
        }
        if (missionary > totalMissionary) {
            throw new MyException(104, "the number of missionaries on this shore exceeded the total number.");
        }
        if (savage > totalSavage) {
            throw new MyException(105, "the number of savages on this shore exceeded the total number.");
        }
        // NOTE(review): this check was truncated in the published listing;
        // the condition follows the safety rule used in getNextState(), the
        // message wording is a reconstruction.
        if (missionary != 0 && missionary < savage) {
            throw new MyException(106, "the number of missionaries on this shore is less than the number of savages.");
        }
        if (boatPosition != THIS_SHORE && boatPosition != OPPOSITE_SHORE) {
            throw new MyException(107, "the value of iBoat is invalid, which must be CPersonState::THIS_SHORE or CPersonState::OPPOSITE_SHORE, you can use 1 or 2 certainly.");
        }
    }

    @Override
    public String toString() {
        return "PersonState [missionary=" + missionary + ", savage=" + savage + ", boatPosition=" + boatPosition
                + ", moveMissionary=" + moveMissionary + ", moveSavage=" + moveSavage + "]";
    }

    /**
     * Hashes exactly the fields that equals() compares. The move fields are
     * deliberately left out: two states reached by different moves are equal,
     * and equal objects must have equal hash codes or the HashMap in Astar
     * will not recognize them as the same state.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + boatPosition;
        result = prime * result + missionary;
        result = prime * result + savage;
        return result;
    }

    /** Two states are equal when head counts and boat position agree. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        PersonState other = (PersonState) obj;
        return boatPosition == other.boatPosition
                && missionary == other.missionary
                && savage == other.savage;
    }

    /**
     * Enumerates every boat load (x missionaries, y savages) that leaves both
     * shores safe, and returns the resulting states as seen from the shore
     * the boat arrives at.
     */
    @Override
    public ArrayList<State> getNextState() {
        ArrayList<State> nextState = new ArrayList<>();
        int oppo_m = totalMissionary - missionary;
        int oppo_s = totalSavage - savage;
        int mk = Math.min(missionary, boatCapacity);
        int sk = Math.min(savage, boatCapacity);
        for (int x = 0; x <= mk; ++x) {
            for (int y = 0; y <= sk; ++y) {
                if (x == 0 && y == 0) continue; // the boat cannot cross empty
                // Shore being left: missionaries, if any remain, must not be
                // outnumbered. Taking more savages may repair this, so go on.
                if (missionary - x != 0 && missionary - x < savage - y) continue;
                // Boat overloaded, or missionaries outnumbered on board: a
                // larger y only makes it worse, so stop this row.
                if (x + y > boatCapacity || (x != 0 && y > x)) break;
                // Arrival shore unsafe: likewise only gets worse with larger y.
                if (oppo_m + x != 0 && oppo_m + x < oppo_s + y) break;

                PersonState _next = new PersonState();
                _next.init(oppo_m + x, oppo_s + y, 3 - boatPosition);
                _next.moveMissionary = x;
                _next.moveSavage = y;
                nextState.add(_next);
            }
        }
        return nextState;
    }

    /** Heuristic: the number of people currently standing on THIS_SHORE. */
    @Override
    public long astar_g() {
        if (boatPosition == THIS_SHORE) {
            return missionary + savage;
        }
        return (totalMissionary - missionary) + (totalSavage - savage);
    }
}
PersonState.java
最后是 main 函数,实现简单的交互:
package main;

import java.util.Scanner;

import tools.MyException;

/**
 * Interactive driver for the river-crossing puzzle: repeatedly reads the
 * number of missionaries, savages and the boat capacity, runs the search and
 * prints the crossing plan. Ends when the input cannot be read as integers.
 */
public class Main {
    public static void main(String[] args) {
        int m = 0, s = 0, k = 0;
        Scanner cin = new Scanner(System.in);
        while (true) {
            System.out.println("please input the number of missionaries, savages and the capacity of the boat, separate with a space:");
            try {
                m = cin.nextInt();
                s = cin.nextInt();
                k = cin.nextInt();
            } catch (Exception ex) {
                System.out.println(ex.toString() + "\nbye~");
                break;
            }
            PersonState.totalMissionary = m;
            PersonState.totalSavage = s;
            PersonState.boatCapacity = k;

            PersonState start = new PersonState();
            PersonState end = new PersonState();
            try {
                start.init(m, s, PersonState.THIS_SHORE);
                end.init(m, s, PersonState.OPPOSITE_SHORE);
            } catch (MyException ex) {
                System.out.println(ex.toString());
                // The states are invalid; searching from them is meaningless,
                // so ask for new input instead.
                continue;
            } catch (Exception ex) {
                break;
            }

            Astar game = new Astar(start, end);
            game.run();
            if (game.canSolve) {
                System.out.println("your game can be solve:");
                System.out.println("the total steps is: " + game.steps);
                System.out.println("and the path is:\n");
                // vecSolve runs from the end state back to the start state,
                // so print it in reverse.
                int len = game.vecSolve.size();
                System.out.println(game.vecSolve.get(len - 1).toString());
                for (int i = len - 2; i >= 0; --i) {
                    System.out.println("\n |");
                    PersonState personState = (PersonState) game.vecSolve.get(i);
                    System.out.println(" | (" + personState.moveMissionary + ", " + personState.moveSavage + ")");
                    System.out.println(" \\|/\n");
                    System.out.println(personState.toString());
                }
                System.out.println();
            } else {
                System.out.println("sorry, your game can't be solve, please input another state.\n");
            }
            System.out.println("the total steps is: " + game.steps);
            System.out.println("and the max states is: " + game.totalStates);
            System.out.println("and the runtime is: " + game.runTime + "\n");
        }
        cin.close();
    }
}
main.java