题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1688
题意:给出一个有重边的有向图,求给出的2个点间有多少条最短路以及和最短路的路程差1的次短路
思路:dijstra算法魔改一下,用队列进行操作,dist[i][0]记录到i点最短路,再用dist[i][1]记录次短路,cnt[i][0]和cnt[i][1]数组则记录最短路和次短路的个数。
每次进行松弛操作的时候有4种情况分别是:
1.到i的距离比记录的最短距离要短,更新dist[i][0]以及dist[i][1],入队进行松弛操作
2.到i的距离比记录的次短路要短但比最短路长,更新dist[i][1],入队进行松弛操作
3.到i的距离和记录最短路一样长,更新最短路的个数
4.到i的距离和记录次短路一样长,更新次短路的个数
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
#include <vector>
#define maxn 10030
#define inf 0x3f3f3f3f
using namespace std;
struct Edge
{
int v,w;
Edge(int a,int b):v(a),w(b){};
};
struct Node
{
int k,dis,pos;
bool operator <(const Node & q)const
{
return dis>q.dis;
}
}st;
priority_queue<Node> que;
int vis[1030][2],dist[1030][2],cnt[1030][2];
vector<Edge> edge[1030];
void init()
{
memset(edge,0,sizeof(edge));
memset(vis,0,sizeof(vis));
memset(dist,inf,sizeof(dist));
memset(cnt,0,sizeof(cnt));
while (!que.empty())
que.pop();
}
void dijstra()
{
que.push(st);
while (!que.empty())
{
Node tmp=que.top();
que.pop();
int u=tmp.pos,k=tmp.k;
if (vis[u][k])
continue;
else
vis[u][k]=1;
for (int i=0;i<edge[u].size();i++)
{
int dis=edge[u][i].w+tmp.dis;
int v=edge[u][i].v;
Node nxt;
if (dis<dist[v][0])
{
cnt[v][1]=cnt[v][0];
dist[v][1]=dist[v][0];
nxt.dis=dist[v][1];
nxt.pos=v;
nxt.k=1;
que.push(nxt);
cnt[v][0]=cnt[u][0];
dist[v][0]=dis;
nxt.dis=dist[v][0];
nxt.pos=v;
nxt.k=0;
que.push(nxt);
}
else if (dis==dist[v][0])
{
cnt[v][0]+=cnt[u][0];
}
else if (dis<dist[v][1])
{
cnt[v][1]=cnt[u][k];
dist[v][1]=dis;
nxt.dis=dis;
nxt.pos=v;
nxt.k=1;
que.push(nxt);
}
else if (dis==dist[v][1])
{
cnt[v][1]+=cnt[u][k];
}
}
}
return;
}
int main()
{
int t,n,m;
scanf("%d",&t);
while (t--)
{
init();
scanf("%d%d",&n,&m);
for (int i=0;i<m;i++)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
edge[u].push_back(Edge(v,w));
}
int stay,end;
scanf("%d%d",&stay,&end);
st.pos=stay;
st.dis=0;
st.k=0;
dist[stay][0]=0;
cnt[stay][0]=1;
dijstra();
int res=cnt[end][0];//cout<<dist[end][0]<<":"<<dist[end][1]<<endl;
if (dist[end][1]==dist[end][0]+1)
res+=cnt[end][1];
printf("%d\n",res);
}
}