如果给定一个“目标状态”,需要求出从初态到目标状态的最小代价,那么优先队列BFS的“优先策略”显然不完善。一个状态的当前代价最小,只能说明从起始状态到当前状态得到代价很小,而在未来的搜索中,从该状态到目标状态可能会花费很大的代价。另外。一些状态虽然虽然当前代价略大,但是未来到目标状态的代价可能会很小,于是从起始状态到目标状态的总代价反而会更优。
为了提高搜索效率,我们很自然想到,可以对未来可能产生的代价进行预估。详细地讲骂我们可以设计一个 ”估价函数“,以任意状态为输入,计算从该状态到目标状态所需代价的估计值。在搜索中,仍维护一个堆,不断从堆中取出 “当前代价+未来估价”最小的状态进行扩展。
这种带有估价函数的优先队列BFS就是A*算法。只要保证对于任意状态state,都有f(state) ≤ \leq ≤g(state),A*算法就一定能在目标状态第一次从堆中被取出时得到最优解,但是搜索过程中每个状态可能被扩展多次,估价f(state)越准确,越接近g(state),A*算法的效率越高。若估价值始终为0,就等价于普通的优先队列BFS。
总结
A* 应用场景:
起点→终点的最短距离
状态空间 >> 1e10
启发函数减小搜索空间
A*算法:
while(q.size())
t ← 优先队列的队头 小根堆
当终点第一次出队时 break;
从起点到当前点的真实距离 d_real
从当前点到终点的估计距离 d_estimate
选择一个两距离之和最小的点
for j in ne[t]:
将邻边入队
A*算法条件:
估计距离<=真实距离
d[state] + f[state] = 起点到state的真实距离 + state到终点的估计距离=估计距离
^
d[state] + g[state] = 起点到state的真实距离 + state到终点的真实距离=真实距离
一定是有解才有 d[i] >= d[最优] = d[u]+f[u]
f[u] >= 0
证明终点第一次出队列即最优解
1 假设终点第一次出队列时不是最优
则说明当前队列中存在点u
有 d[估计]< d[真实]
d[u] + f[u] <= d[u] + g[u]<d[队头终点]
即队列中存在比d[终点]小的值,
2 但我们维护的是一个小根堆,没有比d[队头终点]小的d[u],矛盾
证毕
A* 不用判重
以边权都为1为例
A o→o→o
↑ ↓
S o→o→o→o→o→o→o T
B
dist[A] = dist[S]+1 + f[A] = 7
dist[B] = dist[S]+1 + f[B] = 5
则会优先从B这条路走到T
B走到T后再从A这条路走到T
例题
acwing178.第K短路
首先可以使用数学归纳法得到一个结论:对于任意正整数i和目标节点x,当第i次从堆中取出包含节点x的二元组时,对应的dist值就是从S到x的第k短路。
使用优先队列BFS在最坏条件下时间复杂度为
O
(
K
∗
(
N
+
M
)
∗
l
o
g
(
N
+
M
)
)
O(K*(N+M)*log(N+M))
O(K∗(N+M)∗log(N+M)),这道题给定了起点和终点,可以考虑使用A*提高搜索效率。
我们把估价函数f(x)定为从x到T的最短路长度,这样不但保证了f(x)
≤
\leq
≤g(x),还能顺应g(x)的实际变化趋势。
#include<iostream>
#include<queue>
#include<cstring>
using namespace std;
typedef pair<int,int> PII;
typedef pair<int,pair<int,int>> PIII;
#define MAX_N 1000
#define MAX_M 10000
int n,m;
int h1[MAX_N+5],h2[MAX_N+5];
int e[MAX_M*2+5],ne[MAX_M*2+5],w[MAX_M*2+5];
int v[MAX_N+5],dist[MAX_N+5];
int S,T,K;
int cnt=0,tot=0;
void add(int h[],int a,int b,int c)
{
e[++tot]=b;
w[tot]=c;
ne[tot]=h[a];
h[a]=tot;
}
void dijkstra()
{
memset(dist,0x3f,sizeof dist);
priority_queue<PII,vector<PII>,greater<PII>>heap;
heap.push({0,T});
dist[T]=0;
while(heap.size())
{
PII t=heap.top();
heap.pop();
if(v[t.second])continue;
v[t.second]=1;
for(int i=h2[t.second];i;i=ne[i])
{
if(dist[e[i]]>dist[t.second]+w[i])
{
dist[e[i]]=dist[t.second]+w[i];
heap.push({dist[e[i]],e[i]});
}
}
}
return ;
}
int Astar()
{
priority_queue<PIII,vector<PIII>,greater<PIII>>heap;
heap.push({dist[S],{0,S}});
if(dist[S]==0x3f3f3f3f)return -1;
while(heap.size())
{
PIII t=heap.top();
int distance=t.second.first;
int ver=t.second.second;
heap.pop();
if(ver==T)cnt++;
if(cnt==K)return distance;
for(int i=h1[ver];i;i=ne[i])
{
heap.push({distance+w[i]+dist[e[i]],{distance+w[i],e[i]}});
}
}
return -1;
}
int main()
{
cin>>n>>m;
for(int i=0,a,b,c;i<m;i++)
{
cin>>a>>b>>c;
add(h1,a,b,c);
add(h2,b,a,c);
}
cin>>S>>T>>K;
if(S==T)K++;
dijkstra();
cout<<Astar()<<endl;
return 0;
}
acwing179.八数码
先进行可解性判定。把除空格以外的所有数字排成一个序列,求出该序列的逆序对数。如果初态和终态的逆序对数的奇偶性相同,那么这两个状态可以相互到达,否则一定不能到达。
若问题有解,我们可以使用A*搜索一种移动步数最少的方案,我们把估计函数定为所有数字在state中的对应位置与目标状态end中的位置的曼哈顿距离之和。
#include<iostream>
#include<unordered_map>
#include<queue>
#include<cstring>
using namespace std;
typedef pair<int,string> PIS;
string start,ed,num;
int dx[]={-1,1,0,0},dy[]={0,0,-1,1};
int cnt=0;
string match="udlr";
int f(string s)
{
int ans;
for(int i=0;i<9;i++)
{
if(s[i]=='x')continue;
int t=s[i]-'1';
ans+=abs(i/3-t/3)+abs(i%3-t%3);
}
return ans;
}
string Astar()
{
unordered_map<string,int>dist;
unordered_map<string,pair<char,string>>prev;
priority_queue<PIS,vector<PIS>,greater<PIS>>heap;
heap.push({f(start),start});
dist[start]=0;
while(heap.size())
{
PIS t=heap.top();
heap.pop();
string state=t.second;
if(state==ed)break;
int x,y;
for(int i=0;i<9;i++)
if(state[i]=='x')
{
x=i/3;
y=i%3;
}
string memory=state;
for(int i=0;i<4;i++)
{
int a=x+dx[i],b=y+dy[i];
if(a<0||a>2||b<0||b>2)continue;
swap(state[x*3+y],state[a*3+b]);
if(!dist.count(state)||dist[state]>dist[memory]+1)
{
dist[state]=dist[memory]+1;
prev[state]={match[i],memory};
heap.push({dist[state]+f(state),state});
}
state=memory;
}
}
string ans;
while(ed!=start)
{
ans=prev[ed].first+ans;
ed=prev[ed].second;
}
return ans;
}
int main()
{
char ch;
for(int i=0;i<9;i++)
{
cin>>ch;
start+=ch;
if(ch!='x')num+=ch;
}
ed="12345678x";
for(int i=0;i<8;i++)
for(int j=i+1;j<9;j++)
if(num[i]>num[j])cnt++;
if(cnt&1)cout<<"unsolvable"<<endl;
else cout<<Astar()<<endl;
return 0;
}