题意
已知一张图(单向边),起点S和终点F,求从S到F的最短路和比最短路长1的路径的条数之和。
题解
dijkstra维护四个东西,一个是最短距离,一个是次短距离,一个是最短路径的个数,一个是次短路经的个数。
如果用spfa的话(假的spfa),需要上优先队列,每次选距离最小的,这样写的话,就是一个单个顶点能多次入队的dijkstra。不如直接写dijkstra。
代码
#include <cstdio>
#include <queue>
#include <cstring>
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
const int nmax = 1e3 + 7;
const int INF = 0x3f3f3f3f;
const ll LINF = 0x3f3f3f3f3f3f3f3f;
const ull p = 67;
const ull MOD = 1610612741;
struct Gra {
int u, v, w, next;
} e[10007];
int t, head[nmax], dis[nmax][2], num[nmax][2], use[nmax][2];
void init() {
t = 0;
memset(head, -1, sizeof(head));
}
void add(int u, int v, int w) {
e[t].u = u;
e[t].v = v;
e[t].w = w;
e[t].next = head[u];
head[u] = t++;
}
struct node {
int v, flag, dis;
node() {}
node(int v, int dis, int flag) {
this->v = v;
this->dis = dis;
this->flag = flag;
}
bool operator<(const node &a)const {
return dis > a.dis;
}
};
void spfa(int s, int n) {
queue<node>q;
memset(dis, INF, sizeof(dis));
memset(use, 0, sizeof(use));
dis[s][0] = 0;
num[s][0] = 1;
q.push(node(s, 0, 0));
while (!q.empty()) {
node cur = q.front();
int u = cur.v;
int flag = cur.flag;
q.pop();
if (use[u][flag]) continue;
use[u][flag] = 1;
for (int i = head[u]; i != -1; i = e[i].next) {
int v = e[i].v;
if (dis[v][0] > dis[u][flag] + e[i].w) {
dis[v][1] = dis[v][0];
dis[v][0] = dis[u][flag] + e[i].w;
num[v][1] = num[v][0];
num[v][0] = num[u][flag];
q.push(node(v, dis[v][1], 1));
q.push(node(v, dis[v][0], 0));
} else if (dis[v][0] == dis[u][flag] + e[i].w) {
num[v][0] += num[u][flag];
} else if (dis[v][1] > dis[u][flag] + e[i].w) {
dis[v][1] = dis[u][flag] + e[i].w;
num[v][1] = num[u][flag];
q.push(node(v, dis[v][1], 1));
} else if (dis[v][1] == dis[u][flag] + e[i].w) {
num[v][1] += num[u][flag];
}
}
}
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
int n, m;
scanf("%d%d", &n, &m);
init();
for (int i = 1; i <= m; i++) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c);
}
int ss, tt;
scanf("%d%d", &ss, &tt);
spfa(ss, n);
int ans = 0;
ans += num[tt][0];
if (dis[tt][0] + 1 == dis[tt][1])
ans += num[tt][1];
printf("%d\n", ans);
}
return 0;
}