题目分析
来源:acwing
分析:
本题的估价函数是每个点到终点的最短距离,用的是从终点开始的反向dijkstra()算法。
A*算法:终点第一次从优先队列中弹出来的时候,距离一定取到最小值。那么对于第K短路呢?直观的想法是,终点弹第1次是最小的,终点弹第k次是第k小的。
简单的证明
A*算法就是优先队列bfs + 估价函数,所以整个框架还是bfs的框架。
这里用了邻接表来存图,默写了典型的加边函数add(),由于需要估价函数(某点到终点的预估距离),这里写了dijkstra()算法从终点反向求最短距离,存在dist[]数组,作为估价函数。 由于是反向dijkstra,所以建图的时候建了两遍,一个正向建图,用来A*算法的遍历;一个反向建图,用来求估价函数。
ac代码
#include<bits/stdc++.h>
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;
typedef pair<int, PII> PIII;
const int N = 1010, M = 2e5 + 10;
int n, m, S, T, K;
int h[N], rh[N],e[M], w[M],ne[M], idx;
int dist[N];
bool st[N];// 每个点用没用过
void add(int h[], int a, int b, int c){
e[idx] = b, w[idx] = c, ne[idx] = h[a] , h[a] = idx ++;
}
// 在反向图上dijkstra(),保存估价函数(dist[]距离)
// dist存的是该点到终点的最小距离
void dijkstra(){
priority_queue<PII, vector<PII>, greater<PII>> heap; // 优先队列,小根堆
heap.push({0, T});//终点T是这里的起点 <距离,点编号>
memset(dist, 0x3f, sizeof dist);
dist[T] = 0;
while(heap.size()){
auto t = heap.top();
heap.pop();
int ver = t.y;
if(st[ver]) continue;
st[ver] = true;// 遍历过
// 在反向图上遍历,rh[]数组
for(int i = rh[ver]; i != -1; i = ne[i]){
int j = e[i];
if(dist[j] > dist[ver] + w[i]){
dist[j] = dist[ver] + w[i];
heap.push({dist[j], j});
}
}
}
}
int astar(){
//
priority_queue<PIII,vector<PIII>, greater<PIII>> heap;
// A* 算法,从起点开始搜
heap.push({dist[S],{0, S}});// 估价值,{真实值,编号}
int cnt = 0; // 终点遍历几次
// 无解的话,返回-1
if(dist[S] == 0x3f3f3f3f) return -1;
while(heap.size()){
auto t = heap.top();
heap.pop();
int ver = t.y.y, distance = t.y.x; //从起点到该点的真实距离
if(ver == T) cnt ++; // 遍历一遍终点, cnt++,直到第K次
// 这个点是终点,并且是第k短路,直接返回第k短路的真实距离distance
if(cnt == K) return distance;
// 正向扩展所有的边
// 用 起点到该点的真实距离+ 该点到终点的估价距离来作为标准
// distance + w[i] + dist[j]
for(int i = h[ver]; i != -1; i = ne[i]){
int j = e[i];
heap.push({distance + w[i] + dist[j],{distance + w[i], j}});
}
}
return -1;
}
int main(){
cin >> n >> m;
// 初始化表头
memset(h, -1, sizeof h);
memset(rh, -1, sizeof rh);
for(int i = 0; i < m; i ++){
int a, b, c;
scanf("%d%d%d",&a, &b, &c);
add(h, a, b, c);// 建正边
add(rh, b, a, c); // 建反边
}
scanf("%d%d%d",&S, &T, &K);
if( S == T) K ++;
dijkstra();
cout << astar() << endl;
}