A ∗ A* A∗算法的思路可看:
Introduction to the A ∗ A* A∗ Algorithm
提炼一下:
定义起点 s s s,终点 t t t,从起点开始的距离函数 g ( x ) g(x) g(x) ,到终点的距离函数 h 1 ( x ) h_{1}(x) h1(x) , h 2 ( x ) h_{2}(x) h2(x),以及每个点的估价函数 f ( x ) = g ( x ) + h 1 ( x ) f(x)=g(x)+h_{1}(x) f(x)=g(x)+h1(x),其中 h 1 ( x ) h_{1}(x) h1(x)是我们定义的点 x x x到终点的预估代价函数, h 2 ( x ) h_{2}(x) h2(x)是点 x x x到终点的实际代价函数
启发函数会影响 A ∗ A* A∗算法的行为:
- 在极端情况下,当启发函数 h 1 ( x ) h_{1}(x) h1(x)始终为0,则将由 g ( x ) g(x) g(x)决定节点的优先级,此时算法就退化成了Dijkstra算法
- 如果 h 1 ( x ) h_{1}(x) h1(x)始终小于等于节点 x x x到终点的代价 h 2 ( x ) h_{2}(x) h2(x),则 A ∗ A* A∗算法保证一定能够找到最短路径。但是当 h 1 ( x ) h_{1}(x) h1(x)的值越小,算法将遍历越多的节点,也就导致算法越慢
- 如果 h 1 ( x ) h_{1}(x) h1(x)完全等于节点 x x x到终点的代价 h 2 ( x ) h_{2}(x) h2(x),则 A ∗ A* A∗算法将找到最佳路径,并且速度很快。可惜的是,并非所有场景下都能做到这一点,因为在没有达到终点之前,我们很难确切算出距离终点还有多远
- 如果 h 1 ( x ) h_{1}(x) h1(x)的值比节点 x x x到终点的代价 h 2 ( x ) h_{2}(x) h2(x)要大,则 A ∗ A* A∗算法不能保证找到最短路径,不过此时会很快
- 在另外一个极端情况下,如果 h 1 ( x ) h_{1}(x) h1(x)相较于 g ( x ) g(x) g(x)大很多,则此时只有 h 1 ( x ) h_{1}(x) h1(x)产生效果,这也就变成了最佳优先搜索
所以通过调节启发函数,我们可以控制算法的速度和精确度,在一些情况,我们可能未必需要最短路径,而是希望能够尽快找到一个路径即可,这也是 A ∗ A* A∗算法比较灵活的地方
例一:
思路:
我们知道,在平面上,坐标 ( x 1 , y 1 ) (x_{1},y_{1}) (x1,y1)的 i i i点与坐标 ( x 2 , y 2 ) (x_{2},y_{2}) (x2,y2)的 j j j点的曼哈顿距离为: d ( i , j ) = ∣ x 1 − x 2 ∣ + ∣ y 2 − y 1 ∣ d(i,j)=|x_{1}-x_{2}|+|y_{2}-y_{1}| d(i,j)=∣x1−x2∣+∣y2−y1∣,所以,用每个数和其最终位置的曼哈顿距离作为 h 1 ( x ) h_{1}(x) h1(x),因为其小于 h 2 ( x ) h_{2}(x) h2(x)且差别又不大,所以可以很好的优化
代码:
#include <iostream>
#include <cstring>
#include <queue>
#include <unordered_map>
#include <algorithm>
using namespace std;
typedef pair<int , string> PIS;
unordered_map<string , int> dist;
unordered_map<string , pair<string , char>> pre;
priority_queue<PIS , vector<PIS> , greater<PIS>> heap;
string ed = "12345678x";
int dx[4] = {-1 , 0 , 1 , 0} , dy[4] = {0 , 1 , 0 , -1};
char op[] = "urdl";
int f(string state){//求估值函数,即曼哈顿距离
int res = 0;
for(int i = 0 ; i < 9 ; i++){
if(state[i] != 'x'){
int t = state[i] - '1';
res += abs(t / 3 - i / 3) + abs(t % 3 - i % 3);
}
}
return res;
}
string bfs(string start){
heap.push({f(start) , start});
dist[start] = 0;
while(heap.size()){
auto t = heap.top();heap.pop();
string state = t.second;
int step = dist[state];//记录到达state的实际距离
if(state == ed) break;//如果到达终点就break
int k = state.find('x');
int x = k / 3 , y = k % 3;
string source = state;//因为在下面state会变,所以留一个备份
for (int i = 0; i < 4; i ++ ){
int a = x + dx[i], b = y + dy[i];
if (a >= 0 && a < 3 && b >= 0 && b < 3){
swap(state[x * 3 + y], state[a * 3 + b]);
if (!dist.count(state) || dist[state] > step + 1){
dist[state] = step + 1;
pre[state] = {source, op[i]};
heap.push({dist[state] + f(state), state});
}
swap(state[x * 3 + y], state[a * 3 + b]);//恢复回来
}
}
}
string res;
while(ed != start){
res += pre[ed].second;
ed = pre[ed].first;
}
reverse(res.begin() , res.end());
return res;
}
int main(){
string start , seq;
for(int i = 0 ; i < 9 ; i++){
char c;
cin >> c;
start += c;
if(c != 'x') seq += c;
}
//判断逆序对的数量,如果为奇数,直接无解
int cnt = 0;
for(int i = 0 ; i < 8 ; i ++)
for(int j = i + 1 ; j < 8 ; j++)
if(seq[i] > seq[j])
cnt++;
if(cnt % 2) puts("unsolvable");
else cout << bfs(start) << endl;
return 0;
}
发现可以优化,一个点其实只需进队列一次就可以了:
#include <iostream>
#include <cstring>
#include <queue>
#include <unordered_map>
#include <algorithm>
using namespace std;
typedef pair<int , string> PIS;
unordered_map<string , int> dist;
unordered_map<string , bool> st;
unordered_map<string , pair<string , char>> pre;
priority_queue<PIS , vector<PIS> , greater<PIS>> heap;
string ed = "12345678x";
int dx[4] = {-1 , 0 , 1 , 0} , dy[4] = {0 , 1 , 0 , -1};
char op[] = "urdl";
int f(string state){//求估值函数,即曼哈顿距离
int res = 0;
for(int i = 0 ; i < 9 ; i++){
if(state[i] != 'x'){
int t = state[i] - '1';
res += abs(t / 3 - i / 3) + abs(t % 3 - i % 3);
}
}
return res;
}
string bfs(string start){
heap.push({f(start) , start});
dist[start] = 0;
while(heap.size()){
auto t = heap.top();heap.pop();
string state = t.second;
if(st[state])continue;
st[state]=true;
int step = dist[state];//记录到达state的实际距离
if(state == ed) break;//如果到达终点就break
int k = state.find('x');
int x = k / 3 , y = k % 3;
string source = state;//因为在下面state会变,所以留一个备份
for (int i = 0; i < 4; i ++ ){
int a = x + dx[i], b = y + dy[i];
if (a >= 0 && a < 3 && b >= 0 && b < 3){
swap(state[x * 3 + y], state[a * 3 + b]);
if (!dist.count(state) || dist[state] > step + 1){
dist[state] = step + 1;
pre[state] = {source, op[i]};
heap.push({dist[state] + f(state), state});
}
swap(state[x * 3 + y], state[a * 3 + b]);//恢复回来
}
}
}
string res;
while(ed != start){
res += pre[ed].second;
ed = pre[ed].first;
}
reverse(res.begin() , res.end());
return res;
}
int main(){
string start , seq;
for(int i = 0 ; i < 9 ; i++){
char c;
cin >> c;
start += c;
if(c != 'x') seq += c;
}
//判断逆序对的数量,如果为奇数,直接无解
int cnt = 0;
for(int i = 0 ; i < 8 ; i ++)
for(int j = i + 1 ; j < 8 ; j++)
if(seq[i] > seq[j])
cnt++;
if(cnt % 2) puts("unsolvable");
else cout << bfs(start) << endl;
return 0;
}
例二:
思路:
首先,小根堆每次出堆且是终点第几次出堆就是第几短路(第 k k k次到达终点时的路径长度即为第 k k k短路的长度),然后设计代价函数 h 1 ( x ) h_{1}(x) h1(x):从 x x x点到终点的最短距离,其求法:建反向边跑dijkstra即可
代码:
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;//最小距离,点位置
typedef pair<int, PII> PIII;//预估代价,最小距离,点位置
const int N = 1010, M = 200010;
int n, m, S, T, K;
int h[N], rh[N], e[M], w[M], ne[M], idx;
int dist[N], cnt[N];
bool st[N];
void add(int h[], int a, int b, int c){
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}
void dijkstra(){//跑反图,x的预估代价输出到dist[x]
priority_queue<PII, vector<PII>, greater<PII>> heap;
heap.push({0, T});
memset(dist, 0x3f, sizeof dist);
dist[T] = 0;
while (heap.size()){
auto t = heap.top(); heap.pop();
int ver = t.y;
if (st[ver]) continue;
st[ver] = true;
for (int i = rh[ver]; ~i; i = ne[i]){
int j = e[i];
if (dist[j] > dist[ver] + w[i]){
dist[j] = dist[ver] + w[i];
heap.push({dist[j], j});
}
}
}
}
int astar(){
//预估代价,最小距离,点位置
priority_queue<PIII, vector<PIII>, greater<PIII>> heap;
heap.push({dist[S], {0, S}});
while (heap.size()){
auto t = heap.top(); heap.pop();
//点位置,最小距离
int ver = t.y.y, distance = t.y.x;
cnt[ver] ++ ;
if (cnt[T] == K) return distance;
for (int i = h[ver]; ~i; i = ne[i]){
int j = e[i];
if (cnt[j] < K)
heap.push({distance + w[i] + dist[j], {distance + w[i], j}});
}
}
return -1;
}
int main(){
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
memset(rh, -1, sizeof rh);
for (int i = 0; i < m; i ++ ){
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(h, a, b, c);
add(rh, b, a, c);
}
scanf("%d%d%d", &S, &T, &K);
if (S == T) K ++ ;
dijkstra();
printf("%d\n", astar());
return 0;
}