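// This is a textbook minimum-cut problem. BFS from node 1 (edges treated as unweighted) gives every
// node its layer; walking back from node n then keeps exactly the edges that lie on some shortest
// 1 -> n path. Build a flow network from those edges, using each edge's weight as its capacity, and
// apply a standard max-flow (SAP) template: the maximum flow equals the minimum cut.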
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <iostream>
/* Stand-in for "infinity": the bottleneck capacity given to the source before
   any augmenting path is found.  Written as INT_MAX because `1 << 31 -1`
   parses as `1 << 30` (`-` binds tighter than `<<`). */
const int maxlongint = INT_MAX;
int n, m, tot, ans, s, t, x, y, z;
/* Flow network (residual graph), indexed by edge: a[e] = head of edge e,
   b[e] = residual capacity, next[e] = next edge leaving the same node.
   Edges are added in forward/reverse pairs by insert(). */
int a[500001], b[500001], next[500001];
/* Per-node arrays: last[v] = first flow edge leaving v; SAP state:
   cur = current arc, fa = edge used to reach the node, dist = distance label,
   count = number of nodes per label, dat = bottleneck capacity on the path.
   r and w appear unused in this file. */
int last[5005], r[5005], w[5005], count[5005], cur[5005], fa[5005], dist[5005], dat[5005];
/* Input graph, each undirected edge stored in both directions by addedge():
   tar/nextt/costt are per-edge (head, next edge, capacity) and therefore need
   edge-sized storage; lastt/distt/d are per-node (adjacency head, BFS layer
   from node 1 with 0 = unreached, BFS queue); sum = number of stored edges. */
int tar[500001], nextt[500001], distt[5001], costt[500001], lastt[5001], sum, d[5001];
/* BFS visited marks. */
bool flag[5001];
/*
 * Adds a directed edge x -> y with capacity z to the flow network, plus its
 * reverse residual edge y -> x with capacity 0.  The two edges get
 * consecutive indices (tot+1, tot+2).  Because main() resets tot to 1, the
 * forward index is even, so e ^ 1 is always the partner edge.
 * Returns the index of the forward edge.
 */
int insert(int x, int y, int z) {
    tot++;
    a[tot] = y;
    b[tot] = z;
    next[tot] = last[x];
    last[x] = tot;
    tot++;
    a[tot] = x;
    b[tot] = 0;
    next[tot] = last[y];
    last[y] = tot;
    /* Explicit return: falling off the end of a non-void function is
       undefined behaviour in C++, even when the caller ignores the result. */
    return tot - 1;
}
/*
 * Appends the directed edge x -> y with weight z to the input graph's
 * adjacency lists (tar/costt/nextt, head lastt[x]).  Callers add both
 * directions for an undirected edge.
 * Returns the index of the new edge.
 */
int addedge(int x, int y, int z) {
    sum++;
    tar[sum] = y;
    costt[sum] = z;
    nextt[sum] = lastt[x];
    lastt[x] = sum;
    /* Explicit return: falling off the end of a non-void function is
       undefined behaviour in C++, even when the caller ignores the result. */
    return sum;
}
/* Returns the smaller of the two integers. */
int min(int x, int y) {
    return (y <= x) ? y : x;
}
/*
 * Maximum flow from s to t using SAP (shortest augmenting path), with the
 * current-arc and gap optimisations.
 *
 * Works on the residual network built by insert():
 *   - a[e] is the head of edge e and b[e] its residual capacity;
 *   - next[]/last[] are the adjacency lists;
 *   - e ^ 1 is the reverse edge of e.
 * Nodes are assumed to be numbered 1..t (main calls sap(1, n)).
 * Returns the total flow pushed; b[] is left holding the final residual
 * capacities.
 */
int sap(int s, int t) {
    int i, k, x, flow;

    /* count[] is indexed by distance label, and labels range over 0..t+1.
       Clear every slot so counts left over from a previous test case cannot
       mask a gap. */
    for (i = 0; i <= t + 1; i++) count[i] = 0;

    /* Initial labelling: t at distance 0, every other node at 1.  This is
       valid for every edge (dist[u] <= 1 <= dist[v] + 1), so no BFS is needed. */
    count[0] = 1;
    count[1] = t - 1;
    for (i = 1; i <= t - 1; i++) dist[i] = 1;
    dist[t] = 0;
    for (i = 1; i <= t; i++) {
        cur[i] = last[i];
        fa[i] = 0;
        dat[i] = 0;
    }
    dat[s] = maxlongint;   /* unbounded supply at the source */

    x = s;
    flow = 0;
    while (1) {
        /* Find an admissible arc out of x (positive residual capacity and a
           label exactly one lower), resuming from x's current arc. */
        k = cur[x];
        while (k > 0 && !(b[k] > 0 && dist[a[k]] == dist[x] - 1)) k = next[k];

        if (k > 0) {
            /* Advance along it, carrying the bottleneck capacity forward. */
            cur[x] = k;
            fa[a[k]] = k;
            dat[a[k]] = min(dat[x], b[k]);
            x = a[k];
            if (x == t) {
                /* Augment by the bottleneck along the recorded path; x ends
                   back at the source. */
                flow += dat[t];
                while (x != s) {
                    b[fa[x]] -= dat[t];
                    b[fa[x] ^ 1] += dat[t];
                    x = a[fa[x] ^ 1];
                }
            }
        } else {
            /* No admissible arc: relabel x. */
            count[dist[x]]--;
            /* Gap: no node remains at this label, so no augmenting path can
               exist any more. */
            if (count[dist[x]] == 0) return flow;
            dist[x] = t + 1;
            for (k = last[x]; k != 0; k = next[k]) {
                if (b[k] > 0 && dist[a[k]] + 1 < dist[x]) {
                    dist[x] = dist[a[k]] + 1;
                    cur[x] = k;
                }
            }
            count[dist[x]]++;
            if (dist[s] > t) return flow;
            /* Retreat one step back towards the source. */
            if (x != s) x = a[fa[x] ^ 1];
        }
    }
}
int main() {
int tt;
scanf("%d", &tt);
while (tt--) {
memset(lastt, 0, sizeof(lastt));
memset(distt, 0, sizeof(distt));
sum = 0;
scanf("%d %d", &n, &m);
for (int i = 1; i <= m; i++) {
scanf("%d %d %d", &x, &y, &z);
addedge(x, y, z);
addedge(y, x, z);
}
d[1] = 1;
distt[1] = 1;
memset(flag, 0, sizeof(flag));
flag[1] = true;
int l = 0, r = 1;
while (l < r) {
l++;
int k = lastt[d[l]];
while (k != 0) {
if (!flag[tar[k]]) {
flag[tar[k]] = true;
r++;
d[r] = tar[k];
distt[tar[k]] = distt[d[l]]+1;
}
k = nextt[k];
}
}
memset(d, 0, sizeof(d));
memset(last, 0, sizeof(last));
memset(flag, 0, sizeof(flag));
tot = 1;
d[1] = n;
l = 0; r = 1;
while (l < r) {
l++;
int k = lastt[d[l]];
while (k != 0) {
if (distt[tar[k]] == distt[d[l]]-1) {
if (!flag[tar[k]]) {
r++;
d[r] = tar[k];
flag[tar[k]] = true;
}
insert(tar[k], d[l], costt[k]);
}
k = nextt[k];
}
}
int ans = sap(1, n);
printf("%d\n", ans);
}
}