题目描述抽象来看,是指有一个有向图,问一个点经过N条边到另一个点的最短距离(边可重复走)
为了搞这题...去研究了下矩阵乘法...我不是计算机专业~~又看了下他们的离散数学教材...有一个例子是说求两点间经过N条边到达的方案数..Mtrix67的Blog的第八题讲的也是这个问题....
首先看经过N条边方案数的这个问题...也就是理解一下这个过程...用一个邻接矩阵来存图...点 ( i , j ) 代表 i 到 j 有多少条路...最初矩阵A的初始化时( i , j ) 为两点i到j直接的边数...那么A1存的实际就是每两点只经过一条边到达的方案数...那么看一下 A^2 也就是 A*A ... 做矩阵乘法时是 MutiMtrix [ i ] [ j ] = sum ( Mtrix1 [ i ] [ k ] * Matrix2 [ k ] [ j ] ) < k=1..点数> ...那么就是说枚举所有的中间点(k) sum( A 中i到k点的方案数* A中 k 到 j 的方案数) 很明显能求出每两点之间经过两条边到达的方案数...也就是说 A^2 就代表综上...同理可证A^3代表两点之间经过3条边到达的方案数...A^k代表两点之间经过k条边到达的方案数..
理解了这个例子后再来看这道题...这道题虽然也是经过多少多少条边两点到达..但求的是最短距离...求最短距离..又联想方案数的应该和矩阵有关系..很容易能想到Floyd...Floyd在求最短路径时枚举中间点..不断更新两点两点的最短距离...回想一下Floyd的更新的方程
if ( Dist [ i ] [ j ] < Dist [ i ] [ k ] + Dist [ k ] [ j ] ) Dist [ i ] [ j ] = Dist [ i ] [ k ] + Dist [ k ] [ j ]
这个表达式是不是很酷似矩阵乘法的运算式?多了一层判断再更新,把乘号变成了加号...
因为题目所给的两点最多有一条边...令一个邻接矩阵A表示两两点的初始关系...也就是题目所给的两点相连的情况..( i , j ) 是边的权值...那么( i , j )显然是 i 到 j 经过一条边最短路径长度...定义一个矩阵乘法的形式Floyd的更新方式的矩阵运算:
pp muti(pp a,pp b)
{
pp h;
int i,j,k;
for (i=1;i<=n;i++)
for (j=1;j<=n;j++) h.s[i][j]=oo;
for (k=1;k<=n;k++)
for (i=1;i<=n;i++)
for (j=1;j<=n;j++)
if (h.s[i][j]>a.s[i][k]+b.s[k][j])
h.s[i][j]=a.s[i][k]+b.s[k][j];
return h;
}
如果这时把 a=A b=A来做....得到的结果显然是两点到达经过两条边所需要的最短路径,运用Floyd以及求方案数的思维...为什么这个很类似Floyd的式子求出来的是确定了经过边数的最短距离?因为这个更新和Floyd不同的是更新到一个新的矩阵上去了而不是直接像Floyd的自己更新自己...所以在一更新时...不会出现自己刚更新的值又来继续更新...并且如果a,b矩阵能分辨知道是两点间几条边的最短距离..显然得到的h矩阵是经过a(其矩阵表示的经过边数)+b(其矩阵表示的经过边数)条边两点的最短距离。
这个矩阵运算的方式除了更新,形式上和矩阵乘法时一样的...所以可以运用矩阵乘法的性质来二分求解...例如如果要求 s 到 e 经过N条边到达的最短距离~~实际上是求 A 矩阵做N次后 ( s, e )的值...这里直接就看成乘法来思考..也就是求A^N这个矩阵...明显的用二分的方法来解决就可以了...Mtrix67以及我前一篇文章关于这个方法已经说得很清楚了..
这道题要注意的一点就是虽然点的标号可能是1-1000...但边最多只有100个...所以点最多也就100来个..所以要把点这里处理下~~把离散的点压成从1开始连续的好处理得多...
Program:
#include<iostream>
#define MAXN 106
#define ok printf("Yes %d!!\n",p)
#define oo 1000000001
using namespace std;
struct pp
{
int s[MAXN][MAXN];
}a,h;
struct p1
{
int x,y,k;
}line[MAXN];
int n,t,s,e,i,m,x,y,k,point[1005];
bool had[1005];
pp muti(pp a,pp b)
{
pp h;
int i,j,k;
for (i=1;i<=n;i++)
for (j=1;j<=n;j++) h.s[i][j]=oo;
for (k=1;k<=n;k++)
for (i=1;i<=n;i++)
for (j=1;j<=n;j++)
if (h.s[i][j]>a.s[i][k]+b.s[k][j])
h.s[i][j]=a.s[i][k]+b.s[k][j];
return h;
}
pp find(int p)
{
pp h;
if (p==1) return a;
h=find(p/2);
h=muti(h,h);
if (p%2) h=muti(h,a);
return h;
}
int main()
{
while (~scanf("%d%d%d%d",&n,&t,&s,&e))
{
memset(had,false,sizeof(had));
for (i=1;i<=t;i++)
{
scanf("%d%d%d",&line[i].k,&line[i].x,&line[i].y);
had[line[i].x]=had[line[i].y]=true;
}
m=n; n=0;
for (i=1;i<=1000;i++)
if (had[i])
{
n++;
point[i]=n;
}
for (y=1;y<=n;y++)
for (x=1;x<=n;x++)
a.s[x][y]=oo;
for (i=1;i<=t;i++)
{
x=point[line[i].x]; y=point[line[i].y]; k=line[i].k;
a.s[x][y]=a.s[y][x]=k;
}
h=find(m);
printf("%d\n",h.s[point[s]][point[e]]);
}
return 0;
}