K短路求解算法常用的有djstra + A* 和 Yen算法。本文主要讲解djstra + A*
先了解下A*中的估值函数
f(n)=g(n)+h(n)
f
(
n
)
=
g
(
n
)
+
h
(
n
)
,显然我们可以通过记录到达某一点的花费。即如下节点
struct node {
int h,g,v;
}
- 不断更新从某一点更新,将其压入优先队列之中。
- 每取出一个点时需保证其为由该点出发到达终点的最小花费(即无论通过何种方式更新都不能得到比他更小的值)。
- 同时应当为优先队列中的估计花费最小节点。即保证其 f(n) f ( n ) 最小。
- 当每取出一个节点v,对应的cnt[v]++。即取出节点为该点的第cnt[v]短路。由此也可得到求k短路时从每个点最多更新k次
- 如何求出h
- 建立相反的图
- 有终点出发跑最短路对应每个点的dijstra,显然对应的dist为该点到终点的最小花费。恰好能保证A*中的h的性质。故可在node中省去h
模板
#include <cstring>
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
#define mem(a,b) memset(a,b,sizeof a)
const int INF = 1<<30;
#define sci(a) scanf("%d",&a)
typedef pair<int,int> P;
const int MAX_N = 1005;
const int MAX_E = 100010;
struct edge {
int to,cost,next;
} es[MAX_E],es2[MAX_E];
int N,M;
int head[MAX_N],len;
int head2[MAX_N],len2;
int dist[MAX_N];
int cnt[MAX_N];
void addedge(int from,int to,int cost) {
es[len] = {to,cost,head[from]};
head[from] = len++;
es2[len2] = {from,cost,head2[to]};
head2[to] = len2++;
}
priority_queue<P,
vector
<P>,greater<P> > que;
void djstra(int t) {
fill(dist ,dist + MAX_N,INF);
dist[t] = 0;
while (!que.empty()) que.pop();
que.push({0,t});
while (!que.empty()) {
P p = que.top(); que.pop();
int u = p.second;
if (p.first > dist[u]) continue;
for (int i = head2[u] ; ~i ; i = es2[i].next) {
edge& e = es2[i];
if (dist[u] + e.cost < dist[e.to]) {
dist[e.to] = dist[u] + e.cost;
que.push({dist[e.to],e.to});
}
}
}
}
struct Node {
int h,v;
bool
operator
<(const Node& node2)const {
return h + dist[v] > node2.h + dist[node2.v];
}
};
priority_queue<Node> q;
int Astar(int s,int t,int k) {
if (dist[s] == INF) return -1;
if (s == t) k++;
while (!q.empty()) q.pop();
mem(cnt,0);
q.push({0,s});
while (!q.empty()) {
Node node1 = q.top(); q.pop();
if (cnt[node1.v] > k) continue;
int u = node1.v;
cnt[u]++;
if (cnt[t] == k) return node1.h;
for (int i = head[u] ; ~i ; i = es[i].next) {
q.push({node1.h + es[i].cost,es[i].to});
}
}
return -1;
}
void init() {
mem(head,-1);
mem(head2,-1);
len = len2 = 0;
}
int main() {
int u,v,cost;
int S,T,K;
sci(N); sci(M);
init();
for (int i = 0;i < M;i++) {
sci(u); sci(v);sci(cost);
addedge(u,v,cost);
}
sci(S); sci(T); sci(K);
djstra(T);
printf("%d\n",Astar(S,T,K));
}