题意:
给定一个有 n 个点、m 条边的有向带权图,求从 s 到 t 的第 k 短路的长度(路径允许重复经过点和边);若不存在则输出 -1。(n, k <= 1e3, m <= 1e5)
链接:
https://vjudge.net/problem/POJ-2449
解题思路:
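k 短路模板题(练手)。做法:在反图上从 t 跑 Dijkstra,求出各点到 t 的最短距离并得到最短路树;每条非树边 u->v 的代价增量为 dis[v] + w - dis[u],用可持久化左偏树为每个点维护其到 t 的树路径上所有非树边的增量;最后用优先队列按总增量从小到大扩展,第 k-1 次弹出即对应第 k 短路。注意特判 s == t:此时长度为 0 的空路径不算作一条路,需要把 k 加 1。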
参考代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<queue>
using namespace std;
// Shorthand types, helper macros and global constants.
// (ll, mod, mem, lson, rson and gmid appear unused in the code below.)
using ll = long long;
using pii = pair<int, int>;
#define pb push_back
#define sz(a) ((int)a.size())
#define mem(a, b) memset(a, b, sizeof a)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
constexpr int maxn = 1000 + 5;        // size of vertex-indexed arrays (vertices are 1..n, n <= 1e3)
constexpr int maxm = 1000000 + 5;     // capacity of the heap-node pool tr[]
constexpr int mod = 1000000000 + 7;
constexpr int inf = 0x3f3f3f3f;       // distance value meaning "unreachable"
// One node of a (possibly persistent) leftist min-heap. All nodes live in the
// pool tr[]; index 0 is the shared null node (all fields zero).
struct Node{
    int ls, rs;   // left / right child indices into tr[] (0 = no child)
    int dis;      // distance value compared by merge() to keep the heap leftist
    int val;      // heap key: extra cost dis[v] + w - dis[u] of this sidetrack edge
    int to;       // head vertex v of the sidetrack edge u -> v
} tr[maxm];
vector<pii> G1[maxn];   // forward graph: G1[u] holds (w, v) for every edge u -> v
vector<pii> G2[maxn];   // reverse graph: G2[v] holds (w, u) for every edge u -> v
// Min-priority queue of (key, id) pairs, shared by dij(), build() and solve().
priority_queue<pii, vector<pii>, greater<pii> > q;
int dis[maxn];          // distances filled by dij() (to tp when run on G2)
int fa[maxn];           // parent of u on the shortest-path tree toward tp (0 = none)
int rt[maxn];           // root in tr[] of u's sidetrack heap (0 = empty)
int n, m, sp, tp, k, tot;   // tot = number of nodes allocated in tr[]
int merge(int x, int y, int f){
if(!x || !y) return x + y;
if(tr[x].val > tr[y].val) swap(x, y);
if(f) tr[++tot] = tr[x], x = tot;
tr[x].rs = merge(tr[x].rs, y, f);
if(tr[tr[x].ls].dis < tr[tr[x].rs].dis) swap(tr[x].ls, tr[x].rs);
tr[x].dis = tr[tr[x].rs].dis + 1;
return x;
}
// Reset the per-test-case state for vertices 1..n: adjacency lists, tree
// parents and heap roots. Also empty the shared queue and the node pool.
void init(){
    for(int v = 1; v <= n; ++v){
        G1[v].clear();
        G2[v].clear();
        fa[v] = 0;
        rt[v] = 0;
    }
    while(q.size()) q.pop();
    tot = 0;
}
void dij(int s, vector<pii> G[]){
for(int i = 1; i <= n; ++i) dis[i] = inf;
dis[s] = 0, q.push({dis[s], s});
while(!q.empty()){
int u = q.top().second, d = q.top().first; q.pop();
if(d != dis[u]) continue;
for(int i = 0; i < sz(G[u]); ++i){
int v = G[u][i].second, w = G[u][i].first;
if(dis[v] > dis[u] + w){
dis[v] = dis[u] + w;
q.push({dis[v], v});
}
}
}
}
// Build the sidetrack heaps used by solve().
// Precondition: dis[] holds each vertex's distance to tp (dij on the reverse graph),
// and G is the forward graph with entries (weight, vertex).
// For every vertex u that can reach tp:
//   * fa[u] is the first out-neighbour v with dis[u] == dis[v] + w (a tree edge);
//   * every other edge u -> v with v able to reach tp becomes a heap node keyed by
//     its extra cost dis[v] + w - dis[u] (a "sidetrack");
//   * rt[u] is then persistently merged with rt[fa[u]], so it contains every
//     sidetrack on u's tree path to tp.
// NOTE(review): parents are merged before children only because vertices are popped
// in increasing dis order; this relies on strictly positive edge weights.
void build(vector<pii> G[]){
    for(int u = 1; u <= n; ++u){
        if(dis[u] == inf) continue;
        for(int i = 0; i < sz(G[u]); ++i){
            int v = G[u][i].second, w = G[u][i].first;
            // v cannot reach tp, so this edge never lies on an s-t path. Checking it
            // first also avoids evaluating inf + w, which may overflow int.
            if(dis[v] == inf) continue;
            if(!fa[u] && dis[u] == dis[v] + w) fa[u] = v;
            else{
                // A fresh single-node heap has dis 1, while the null node tr[0] has
                // dis 0. This lets merge() tell a leaf from an empty subtree, so the
                // heaps stay leftist.
                tr[++tot] = {0, 0, 1, dis[v] + w - dis[u], v};
                rt[u] = merge(rt[u], tot, 0);
            }
        }
        q.push({dis[u], u});
    }
    // Pop in increasing distance so rt[fa[u]] is complete before it is merged into rt[u].
    while(!q.empty()){
        int u = q.top().second; q.pop();
        if(fa[u]) rt[u] = merge(rt[u], rt[fa[u]], 1);
    }
}
int solve(){
dij(tp, G2); build(G1);
if(sp == tp) ++k;
if(!rt[sp]) return -1;
if(k == 1) return dis[sp];
int res = k - 1;
q.push({tr[rt[sp]].val, rt[sp]});
while(!q.empty()){
int u = q.top().second, d = q.top().first; q.pop();
if(--res == 0) return dis[sp] + d;
if(int v = tr[u].ls) q.push({d - tr[u].val + tr[v].val, v});
if(int v = tr[u].rs) q.push({d - tr[u].val + tr[v].val, v});
if(int v = rt[tr[u].to]) q.push({d + tr[v].val, v});
}
return -1;
}
// Process test cases until EOF. Each case is "n m", then m lines "u v w"
// (a directed edge u -> v of weight w), then "s t k". Prints one answer per case.
int main(){
    while(scanf("%d%d", &n, &m) != EOF){
        init();
        for(int e = 0; e < m; ++e){
            int from, to, cost;
            scanf("%d%d%d", &from, &to, &cost);
            G1[from].pb({cost, to});     // forward graph
            G2[to].pb({cost, from});     // reverse graph, for distances to t
        }
        scanf("%d%d%d", &sp, &tp, &k);
        printf("%d\n", solve());
    }
    return 0;
}