题目描述:
给定一张由T条边构成的无向图,点的编号为1~1000之间的整数。
求从起点S到终点E恰好经过N条边(可以重复经过)的最短路。
注意: 数据保证一定有解。
输入格式
第1行:包含四个整数N,T,S,E。
第2..T+1行:每行包含三个整数,描述一条边的边长以及构成边的两个点的编号。
输出格式
输出一个整数,表示最短路的长度。
数据范围
2≤T≤100,
2≤N≤10^6
输入样例:
2 6 6 4
11 4 6
4 4 8
8 4 9
6 6 8
2 6 9
3 8 9
输出样例:
10
分析:
本题是求恰好经过k条边的最短路径,y总对这题的讲解比较粗糙 ,很多先修知识点并没有讲解,这题是可以将很多重要的知识点串起来的好题,仔细分析会大有收获的。有边数限制的最短路径问题,第一反应自然是Bellman_Ford算法,只需对Bellman Ford算法的逻辑稍加修改,再吸吸氧应该就可以AC了,下面详细介绍下用矩阵快速幂的方法来求解本题。
一、图论中的矩阵乘法
设g[i][j] = 1表示i到j之间存在关联边,g[i][j] = 0表示不存在关联边,t[i][j]表示从i恰好经过两条边到达j的路径条数,很容易得出状态转移方程t[i][j] += g[i][k] * g[k][j],k从1到n,如何得到这个式子的呢?只有i到k和k到j都存在边的时候,从i到k再到j才是刚好经过两条边的路径,所有这种路径条数的和就是t[i][j]了,转化为代码的形式就是:
for(int k = 1;k <= n;k++)
for(int i = 1;i <= n;i++)
for(int j = 1;j <= n;j++)
t[i][j] += g[i][k] * g[k][j];
观察上面的代码可以发现,上面的代码正是求矩阵t = g * g的代码,也就是说,对没有边权的图的邻接矩阵g而言,g * g中的第i行第j列就表示了从i到j存在几条路径长度为2的路径,同理g * g * g = t * g就等于从i先经过两条边再经过一条边到达j的路径条数,也就是经过三条边的路径条数,以此类推可以推出,g的k字方中存储的就是任意两点间经过k条边路径的条数。这里我们可以看出两点,其一是在图论中矩阵乘法具有求两点间经过若干条边路径条数的功能;其二是对邻接矩阵做一次自乘就可以得到结果两条边的路径数,如果需要求经过更多条边的路径数,就需要求g的k字方了。
二、广义的矩阵乘法
上面的问题的实现就是矩阵乘法的代码,但是现实中很多问题并不是要我们求路径的条数,比如本题就是求最短路径长度。现在设g就是图的邻接矩阵,t[i][j]表示从i恰好经过两条边到达j的最短路径长度,则t[i][j] = min(g[i][k] + g[k][j]);转化为代码的形式就是:
for(int k = 1;k <= n;k++)
for(int i = 1;i <= n;i++)
for(int j = 1;j <= n;j++)
t[i][j] = min(t[i][j],g[i][k] * g[k][j]);
可以发现,这次的代码与矩阵乘法的代码很相似,可以看作广义的矩阵乘法。
三、矩阵快速幂对动态规划的优化
首先还是介绍矩阵快速幂是如何求普通矩阵的n字方的,之前写过的关于矩阵快速幂的题解见AcWing 205 斐波那契。首先快速幂是如何求a^b的,比如求3^13,13 = 1101,也就是13 = 8 + 4 + 1,3^13 = 3^8 * 3^4 * 3^1,普通数的乘法具有结合律,a*b*c = a * (b * c),所以快速幂可以将一系列的乘法拆分成几组乘法。对于矩阵而言也具有结合律,A * B * C = A * (B * C),矩阵t = g * g的展开式就是t[i][j] = sum(g[i][.] * g[.][j]),g * g * g = t * g = sum(t[i][.] * g[.][j]) = sum(g[i][.] * t[.][j]) = g * t,这就是矩阵乘法具有结合律的原因,有了结合律,六个矩阵相乘,就可以先算前两个矩阵的乘积和后四个矩阵的乘积,最后再乘到一起即可。
void mul(int c[][N],int a[][N],int b[][N]){
static int t[N][N];
memset(t,0,sizeof t);
for(int k = 1;k <= n;k++)
for(int i = 1;i <= n;i++)
for(int j = 1;j <= n;j++)
t[i][j] += a[i][k] + b[k][j];
memcpy(c,t,sizeof t);
}
void qmi(){
memset(f,0,sizeof f);
for(int i = 1;i <= n;i++) f[i][i] = 1;
while(k){
if(k & 1) mul(f,f,g);
mul(g,g,g);
k >>= 1;
}
}
上面的代码就是矩阵快速幂对普通矩阵乘法的优化,也就是将快速幂代码中的普通乘法修改为矩阵乘法mul即可。看起来很简单,但是很多细节不自己去思考是不知道为什么要这么设置的,比如qmi里面的代码,为什么对存储矩阵乘法的结果矩阵f要初始化为单位矩阵呢?我们考虑边界情况,k = 1时候,f = f * g,我们需要g^1 = g,也就是需要f * g = g,所以f的初值需要设置为单位矩阵。矩阵乘法的函数里对t的初值设置为0无可厚非,为什么需要先定义一个备份矩阵t存下矩阵乘法的结果再在计算完之后拷贝给矩阵c呢?是因为很可能传进去的a、b、c都是一个矩阵,对同一个二维数组做自乘必然会改变矩阵原始的值引发运算错误,所以需要t存储中间的计算结果。
再来说下矩阵快速幂对广义矩阵乘法的优化。我们将二中实现t[i][j] = min(g[i][k] + g[k][j]);的代码,也就是广义矩阵乘法用mul表示,mul可以求g mul g,但是如何求若干个g相乘呢?考察min运算是否具有结合律,t[i][j] = min(g[i][.] + g[.][j]);,则min(t[i][.] + g[.][j]) = min(g[i][.] + t[.][j]),也就是说广义的矩阵乘法mul也具有结合律,所以也可以用矩阵快速幂对其做优化。上面说过,广义矩阵乘法mul可以求i经过两条边到达j的最短路径长度,则对g做k次自乘就可以求经过k条边的最短路径长度了。原理很好理解,t = g mul g,t[i][k] + g[k][j]就是i经过k到达j也就是经过三条边的路径长度,在其中取个min就是经过三条边的最短路径长度了,k条边以此类推。
void mul(int c[][N],int a[][N],int b[][N]){
static int t[N][N];
memset(t,0x3f,sizeof t);
for(int k = 1;k <= n;k++)
for(int i = 1;i <= n;i++)
for(int j = 1;j <= n;j++)
t[i][j] = min(t[i][j],a[i][k] + b[k][j]);
memcpy(c,t,sizeof t);
}
void qmi(){
memset(f,0x3f,sizeof f);
for(int i = 1;i <= n;i++) f[i][i] = 0;
while(k){
if(k & 1) mul(f,f,g);
mul(g,g,g);
k >>= 1;
}
}
上面是矩阵快速幂优化本题中广义矩阵乘法的代码。对t[i][j]的求法改变是理所应当的,但是仔细观察会发现,前面矩阵快速幂对普通矩阵乘法的优化强调过的地方在这里就完全不一样了。首先,qmi函数里结果矩阵f的初值变成了主对角线元素都是0,其他元素都是INF了,这样做能否使得k = 1时,g * f还等于g呢?mul中求a[i][k] + b[k][j]时,也就是求f[i][k] + g[k][j],f[i][k]除了在k = i时候是0,其他时候都是INF,不会更新最小值,而f[i][i] + g[i][j]就算i到j的距离,也就没有改变g的初值,这就是f的初值这样设置的原因。对于mul中备份数组t的初值,设置为INF就理所应该了,因为后面需要求的是最小值。
本题的完整代码如下,值得注意的是点的编号最大是1000,最最多只有100条边,也就是最多200个节点,所以可以对节点编号做下离散化,较为简单,参考代码即可。还有个要注意的地方时,本题读取边的信息的时候是先读取边长,在读取两个节点的编号,一般题目都是最后读取边权,要注意不要写错数据的读取代码。
#include <iostream>
#include <unordered_map>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 205;
int n,k,g[N][N],f[N][N];
unordered_map<int,int> um;
void mul(int c[][N],int a[][N],int b[][N]){
static int t[N][N];
memset(t,0x3f,sizeof t);
for(int k = 1;k <= n;k++)
for(int i = 1;i <= n;i++)
for(int j = 1;j <= n;j++)
t[i][j] = min(t[i][j],a[i][k] + b[k][j]);
memcpy(c,t,sizeof t);
}
void qmi(){
memset(f,0x3f,sizeof f);
for(int i = 1;i <= n;i++) f[i][i] = 0;
while(k){
if(k & 1) mul(f,f,g);
mul(g,g,g);
k >>= 1;
}
}
int main(){
int m,s,e,a,b,c;
cin>>k>>m>>s>>e;
um[s] = ++n;
if(!um.count(e)) um[e] = ++n;
s = um[s],e = um[e];
memset(g,0x3f,sizeof g);
for(int i = 0;i < m;i++){
cin>>c>>a>>b;
if(!um.count(a)) um[a] = ++n;
if(!um.count(b)) um[b] = ++n;
a = um[a],b = um[b];
g[a][b] = g[b][a] = min(g[a][b],c);
}
qmi();
cout<<f[s][e]<<endl;
return 0;
}