链接
题解
首先很容易想到dp,
d
p
[
i
,
j
,
k
]
dp[i, j, k]
dp[i,j,k]表示从
i
i
i到
j
j
j经过
k
k
k条边的最短路,但是很明显会超时,这里使用矩阵加速,定义:
A
r
[
i
,
j
]
:
从
i
到
j
经
过
r
条
边
的
最
短
路
A^r[i, j]:从i到j经过r条边的最短路
Ar[i,j]:从i到j经过r条边的最短路
那么
∀
i
,
j
∈
[
1
,
n
]
,
(
A
r
+
m
)
[
i
,
j
]
=
m
i
n
1
≤
k
≤
n
{
(
A
r
)
[
i
,
k
]
+
(
A
m
)
[
k
,
j
]
}
\forall i, j \in [1, n],(A^{r + m})[i, j] = min_{1\leq k \leq n}\{(A^{r})[i, k] + (A^{m})[k, j]\}
∀i,j∈[1,n],(Ar+m)[i,j]=min1≤k≤n{(Ar)[i,k]+(Am)[k,j]}.
能够用矩阵进行加速的必要条件是矩阵间的运算满足结合律,结合到这道题就是
i
i
i到
j
j
j经过k条边的最短路可以是先经过
1
1
1条边在经过
k
−
1
k-1
k−1,也可以反过来或者是其他,总之加和为k即可。
代码
#include <bits/stdc++.h>
using namespace std;
#define REP(i, n) for (int i = 1; i <= (n); i++)
#define sqr(x) ((x) * (x))
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
const int maxn = 200 + 10;
// const int maxn = 25;
const int maxm = 150000 + 100;
// const int maxm = 30;
const int maxt = 100 + 5;
const int maxk = 1000 + 10;
typedef long long LL;
typedef long double LD;
typedef unsigned long long uLL;
typedef pair<int, int> pii;
typedef pair<double, double> pdd;
const LL unit = 1LL;
const int INF = 0x3f3f3f3f;
const LL Inf = 0x3f3f3f3f3f3f3f3fLL;
const double eps = 1e-8;
const double inf = 1e15;
const double pi = acos(-1.0);
const LL mod = 1000000007;
int n, m, st, en;
int tot = 200, a[maxn][maxn];
map<int, int> mmap;
struct Mat
{
int a[maxn][maxn];
void init()
{
for (int i = 1; i <= tot; ++i)
for (int j = 1; j <= tot; ++j)
a[i][j] = INF;
}
} s, base;
inline int get_id(int x)
{
if(mmap.count(x))
return mmap[x];
return mmap[x] = ++tot;
}
Mat mul(Mat a, Mat b)
{
Mat c;
c.init();
for (int i = 1; i <= tot; ++i)
for (int j = 1; j <= tot; ++j)
for (int k = 1; k <= tot; ++k)
c.a[i][j] = min(c.a[i][j], a.a[i][k] + b.a[k][j]);
return c;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
// freopen("C:\\Users\\WA_TERMINATOR\\Desktop\\Helloworld\\input.txt", "r", stdin);
// freopen("output.txt", "w", stdout);
while (cin >> n >> m >> st >> en)
{
s.init();
tot = 0;
mmap.clear();
int u, v, w;
for (int i = 0; i < m; i++)
{
cin >> w >> u >> v;
int uu = get_id(u), vv = get_id(v);
s.a[uu][vv] = s.a[vv][uu] = min(s.a[uu][vv], w);
}
memcpy(base.a, s.a, sizeof(base.a));
--n;
while (n)
{
if (n & 1)
s = mul(s, base);
base = mul(base, base);
n >>= 1;
}
cout << s.a[get_id(st)][get_id(en)] << "\n";
}
return 0;
}