题目传送门:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=2760
【题目描述】
给一个N个点的图,用邻接矩阵表示点的联通情况,-1表示无边,否则表示边的长度。定义两条路径非重叠为两条路径没有公共边的情况,给出源点和汇点,求从源点到汇点有多少条非重叠的最短路径。
【输入格式】
输入包含多组数据,每组数据第一行为N,接下来给出一个N*N的矩阵表示有向图,每一个元素要么是非负整数,表示边权,要么是-1,表示没有边。最后给出两个点s,t表示源点和汇点,点的编号0~N-1。
【输出格式】
对于每组测试数据,输出一个整数表示非重叠路径数,如果源、汇点相同,就输出inf。
【样例输入】
4
0 1 1 -1
-1 0 1 1
-1 -1 0 1
-1 -1 -1 0
0 3
5
0 1 1 -1 -1
-1 0 1 1 -1
-1 -1 0 1 -1
-1 -1 -1 0 1
-1 -1 -1 -1 0
0 4
【样例输出】
2
1
【备注】
N<=100
【题目分析】
非常好的网络流板题。
考虑建图,我们先从s开始找到一条最短路s->t,然后s到所有点的最短路径就都找到了,那么我们再从s开始找一遍,如果对于点u,v满足dis[u]+w(u,v)=dis[v],那么这条边一定可以作为最短路上的边,在网络流的图上我们就将u,v连边,因为只能走一次,所以流量上限设为1。最后在形成的图上跑一遍网络流即可(不过最短路我用的是Floyd,但也只有70ms)。
【代码~】
#include<bits/stdc++.h>
using namespace std;
const int MAXN=5e4+10;
const int INF=0x3f3f3f3f;
struct node{
int nxt,w,to;
}e[MAXN];
int cnt,t,st,q[MAXN],n,m,h,d,num,s,tt;
int depth[MAXN],head[MAXN],cur[MAXN];
int ma[300][300],dis[300][300];
bool bfs()
{
memset(depth,-1,sizeof(depth));
depth[st]=0;
int r=1,l=1;
r++;
q[r]=st;
int v,u;
while(l<r)
{
l++;
u=q[l];
for(int i=head[u];i!=-1;i=e[i].nxt)
{
v=e[i].to;
if(depth[v]!=-1)
continue;
if(e[i].w==0)
continue;
depth[v]=depth[u]+1;
r++;
q[r]=v;
if(v==t)
return true;
}
}
return false;
}
int dfs(int x,int mx)
{
if(x==t||mx==0)
return mx;
int f,flow=0,v,ret=0;
for(int i=head[x];i!=-1;i=e[i].nxt)
{
v=e[i].to;
if(depth[x]+1!=depth[v])
continue;
if((f=dfs(v,min(mx,e[i].w))))
{
e[i].w-=f;
e[i^1].w+=f;
flow+=f;
ret+=f,mx-=f;
if(!mx)
break;
}
}
if(ret==0)
depth[x]=-1;
return flow;
}
int dinic()
{
int tmp=0,maxflow=0;
while(bfs())
{
while(tmp=dfs(st,INF))
maxflow+=tmp;
}
return maxflow;
}
void add(int a,int b,int c)
{
e[cnt].to=b,e[cnt].nxt=head[a],e[cnt].w=c,head[a]=cnt,cnt++;
e[cnt].to=a,e[cnt].nxt=head[b],e[cnt].w=0,head[b]=cnt,cnt++;
}
int main()
{
while(~scanf("%d",&n))
{
memset(head,-1,sizeof(head));
cnt=0;
st=n+2;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
scanf("%d",&ma[i][j]);
memcpy(dis,ma,sizeof(ma));
for(int i=1;i<=n;i++)
dis[i][i]=0;
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)
dis[i][j]=(dis[i][j]==-1?INF:dis[i][j]);
for(int k=1;k<=n;++k)
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)
if(dis[i][k]<INF&&dis[k][j]<INF&&dis[i][j]>dis[i][k]+dis[k][j])
dis[i][j]=dis[i][k]+dis[k][j];
scanf("%d%d",&s,&t);
s++,t++;
for(int i=1;i<=n;i++)
{
if(dis[s][i]==INF)
continue;
for(int j=1;j<=n;j++)
{
if(i==j)
continue;
if(ma[i][j]==-1||dis[s][j]==INF)
continue;
if(dis[s][i]+ma[i][j]==dis[s][j])
add(i,j,1);
}
}
if(s==t)
{
printf("inf\n");
continue;
}
add(st,s,INF);
printf("%d\n",dinic());
}
return 0;
}