一、题目概述
题目链接:Libre OJ。
给出一张图,求出 a n s = ∑ i = 1 n ∑ j = 1 n [ i ≠ j ] d k ( i , j ) ans=\sum_{i=1}^n\sum_{j=1}^n[i\ne j]\texttt d^k(i,j) ans=∑i=1n∑j=1n[i̸=j]dk(i,j) 的值,其中 d ( x , y ) \texttt d(x,y) d(x,y) 表示从 x x x 到 y y y 的最短路。
对于所有数据满足
1
≤
n
≤
1
0
5
,
1
≤
k
≤
1
0
9
1\le n\le 10^5,1\le k\le 10^9
1≤n≤105,1≤k≤109,保证给定的图
G
G
G 满足题中要求,且不存在重边。
subtask
1:
5
%
\texttt{subtask 1:}~ 5\%
subtask 1: 5%,满足
n
≤
1000
n\le 1000
n≤1000 。
subtask
2:
10
%
\texttt{subtask 2:}~10\%
subtask 2: 10%,满足
k
=
1
k=1
k=1 。
subtask
3:
15
%
\texttt{subtask 3:}~15\%
subtask 3: 15%,满足
k
=
2
k=2
k=2 。
subtask
4:
30
%
\texttt{subtask 4:}~30\%
subtask 4: 30%,满足
G
G
G 中存在一条边
(
u
,
u
)
(u,u)
(u,u) 。
subtask
5:
40
%
\texttt{subtask 5:}~40\%
subtask 5: 40%,无额外限制。
二、解题思路
算法0
首先肯定可以 O ( n 3 ) \mathcal O(n^3) O(n3) 跑 Floyd \texttt{Floyd} Floyd ,但是这个是集训队互测的题,可能是由于集训队大佬们都不屑于写,于是就没有这一档部分分。
算法1
数据满足 n ≤ 1000 n\le 1000 n≤1000 。
考虑分类讨论:
如果这是一棵树的话,那么问题就是求 a n s = ∑ i = 1 n ∑ j = 1 n [ i ≠ j ] ( dep i + dep j − 2 dep LCA ( i , j ) ) k ans=\sum_{i=1}^n\sum_{j=1}^n[i\ne j](\texttt{dep}_i+\texttt{dep}_j-2\texttt{dep}_{\texttt{LCA}(i,j)})^k ans=∑i=1n∑j=1n[i̸=j](depi+depj−2depLCA(i,j))k ,直接枚举即可。
如果原图是一棵基环树的话,使用基环树的常规套路,先找出环,并且删去环上的任意一条边,剩下的边按照树的方案做一次。做完之后,我们单独考虑这条边 ( u , v ) (u,v) (u,v) 对最短路的贡献,即有哪些路径可以使用 s ⟶ u → v ⟶ t s\longrightarrow u\rightarrow v\longrightarrow t s⟶u→v⟶t 作为最短路。于是我们先从 u u u 和 v v v 向其它点跑一边最短路,然后 O ( n 2 ) \mathcal O(n^2) O(n2) 枚举 s s s 和 t t t ,然后暴力更新即可。
算法2
数据满足 k = 1 k=1 k=1 。
考虑分类讨论:
如果原图是树,那么问题就是求 a n s = ∑ i = 1 n ∑ j = 1 n [ i ≠ j ] dep i + dep j − 2 dep LCA ( i , j ) ans=\sum_{i=1}^n\sum_{j=1}^n [i\ne j]\texttt{dep}_i+\texttt{dep}_j-2\texttt{dep}_{\texttt{LCA}(i,j)} ans=∑i=1n∑j=1n[i̸=j]depi+depj−2depLCA(i,j) ,直接枚举每个点作为路径的端点,和作为 LCA \texttt{LCA} LCA 对答案的贡献即可。
如果原图是基环树,那么我们可以分析一下这张图的 dfs
树会长成什么样:
这就意味着,新加入的边如果对某一条路径有影响,那么这条路径的一定是一端在 v1 \texttt{v1} v1 的子树之外,另一端在 vn \texttt{vn} vn 的子树之内。通过这个性质,我们就可以将基环树上的问题转化为树上的问题求解了。
这样的话,我们还可以顺便解决另外一个部分分:
数据满足 k = 2 k=2 k=2 。
答案变为 a n s = ∑ i = 1 n ∑ j = 1 n [ i ≠ j ] ( dep i + dep j − 2 dep LCA ( i , j ) ) 2 ans=\sum_{i=1}^n\sum_{j=1}^n [i\ne j](\texttt{dep}_i+\texttt{dep}_j-2\texttt{dep}_{\texttt{LCA}(i,j)})^2 ans=∑i=1n∑j=1n[i̸=j](depi+depj−2depLCA(i,j))2 ,将完全平方展开后利用上面的方法维护即可。
算法3
满足 G G G 中存在一条边 ( u , u ) (u,u) (u,u) ,即原图为树。
考虑到答案中存在乘方操作,不容易计算,于是令 c n t i cnt_i cnti 表示树上长度为 i i i 的路径的个数,那么答案就是 a n s = ∑ i = 1 n c n t i × i k ans=\sum_{i=1}^ncnt_i\times i^k ans=∑i=1ncnti×ik。不难发现, c n t cnt cnt 可以使用点分治维护,然后用多项式算法优化复杂度,于是这个部分分的问题就解决了。
算法4
数据满足 1 ≤ n ≤ 1 0 5 , 1 ≤ k ≤ 1 0 9 1\le n\le 10^5,1\le k\le 10^9 1≤n≤105,1≤k≤109,保证给定的图 G G G 满足题中要求,且不存在重边。
如果原图是树的话,直接用上面的算法3就好了。
如果是基环树的话,考虑基环DP,先做环以外的点的树形DP,然后在换上合并即可。
参考代码:(代码格式化 Powered by Libre OJ)
#include <iostream>
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
const int mod = 998244353, inv2 = 499122177, inv3 = 332748118;
vector<int> ve[500005], cnt[500005];
int n, rev[500005], cy[500005], dep[500005], fa[500005], sz[500005], vis[500005];
long long f[500005], ans[500005], a[500005], b[500005];
long long qpow(long long a, int b = mod - 2) {
long long rtv = 1;
for (a %= mod; b; b >>= 1, a = a * a % mod)
if (b & 1)
rtv = rtv * a % mod;
return rtv;
}
inline void ntt1(long long* f, int n) {
for (register int i=1;i<n;i+=2) rev[i-1]=rev[i]=rev[i>>1]>>1,rev[i]|=n>>1;
for (register int i = 0; i < n; ++i)
if (i < rev[i])
swap(f[i], f[rev[i]]);
for (register int i = 1; i < n; i <<= 1) {
long long w = qpow(3, mod / (i << 1));
for (register int j = 0; j < n; j += i << 1) {
long long o = 1;
for (register int k = 0; k < i; ++k, o = o * w % mod) {
long long tmp1 = f[j + k], tmp2 = f[i + j + k] * o % mod;
f[j+k]=(tmp1+tmp2)%mod,f[i+j+k]=(tmp1-tmp2+mod)%mod;
}
}
}
return;
}
inline void ntt2(long long* f, int n) {
for (register int i = 0; i < n; ++i)
if (i < rev[i])
swap(f[i], f[rev[i]]);
for (register int i = 1; i < n; i <<= 1) {
long long w = qpow(inv3, mod / (i << 1));
for (register int j = 0; j < n; j += i << 1) {
long long o = 1;
for (register int k = 0; k < i; ++k, o = o * w % mod) {
long long tmp1 = f[j + k], tmp2 = f[i + j + k] * o % mod;
f[j+k]=(tmp1+tmp2)%mod,f[i+j+k]=(tmp1-tmp2+mod)%mod;
}
}
}
long long _ = qpow(n);
for (register int i = 0; i < n; ++i) f[i] = f[i] * _ % mod;
return;
}
inline void ntt3(long long* f, long long* g, int n) {
for (register int i=1;i<n;i+=2) rev[i-1]=rev[i]=rev[i>>1]>>1,rev[i]|=n>>1;
for (register int i = 0; i < n; ++i)
if (i < rev[i])
swap(f[i], f[rev[i]]), swap(g[i], g[rev[i]]);
for (register int i = 1; i < n; i <<= 1) {
long long w = qpow(3, mod / (i << 1));
for (register int j = 0; j < n; j += i << 1) {
long long o = 1;
for (register int k = 0; k < i; ++k, o = o * w % mod) {
long long tmp1 = f[j + k], tmp2 = f[i + j + k] * o % mod;
f[j+k]=(tmp1+tmp2)%mod,f[i+j+k]=(tmp1-tmp2+mod)%mod;
tmp1 = g[j + k], tmp2 = g[i + j + k] * o % mod;
g[j+k]=(tmp1+tmp2)%mod,g[i+j+k]=(tmp1-tmp2+mod)%mod;
}
}
}
return;
}
int deg[500005];
queue<int> q;
inline void bfs(void) {
while (!q.empty()) q.pop();
for (register int i = 1; i <= n; ++i)
if ((deg[i] = ve[i].size()) == 1)
q.push(i);
while (!q.empty()) {
for (register int i : ve[q.front()])
if (--deg[i] == 1)
q.push(i);
q.pop();
}
int s = 1, cur, lst = 0;
while (deg[s] < 2) ++s;
cur = s;
while (1) {
cy[cy[500003]] = cur, ++cy[500003];
for (register int i : ve[cur]) {
if (i ^ lst && deg[i] > 1) {
lst = cur, cur = i;
break;
}
}
if (cur == s)
break;
}
return;
}
int qu[500005], _h, _t;
inline int grt(int s) {
qu[_h = _t = 1] = s, dep[s] = fa[s] = 0;
while (_h <= _t) {
int cur = qu[_h];
sz[cur] = 1;
for (register int i : ve[cur])
if (!vis[i] && i ^ fa[cur])
dep[i] = dep[cur] + 1, fa[qu[++_t] = i] = cur;
++_h;
}
for (register int i = _t; i; --i) {
if (sz[qu[i]] >= _t + 1 >> 1)
return qu[i];
sz[fa[qu[i]]] += sz[qu[i]];
}
}
inline void calc1(int root, int fact, int dis) {
grt(root);
for (register int i = 1; i <= _t; ++i) ++f[dep[qu[i]] += dis];
int len = 1;
while (len <= _t << 1) len <<= 1;
ntt1(f, len);
for (register int i = 0; i < len; ++i) f[i] = f[i] * f[i] % mod;
ntt2(f, len);
for (register int i = 1; i <= _t; ++i) --f[dep[qu[i]] << 1];
fact = 1LL * fact * inv2 % mod;
for (int i=0;i<=min(n,_t<<1);++i) ans[i+1]=(ans[i+1]+f[i]*fact%mod)%mod;
for (register int i = 0; i < len; ++i) f[i] = 0;
return;
}
void treedp(int root) {
vis[root = grt(root)] = 1;
calc1(root, 1, 0);
for (register int i : ve[root])
if (!vis[i])
calc1(i, -1, 1);
for (register int i : ve[root])
if (!vis[i])
treedp(i);
return;
}
inline void gdp(void) {
for (register int i = 1; i <= n; ++i) vis[i] = 0;
for (register int i = 0; i < cy[500003]; ++i) {
vis[cy[(i+cy[500003]-1)%cy[500003]]]=vis[cy[(i+1)%cy[500003]]]=1;
vis[cy[i]] = 0, treedp(cy[i]);
}
return;
}
inline void mul(int len) {
ntt3(a, b, len);
for (register int i = 0; i < len; ++i) a[i] = a[i] * b[i] % mod;
ntt2(a, len);
return;
}
inline void calc2(int l, int r, int ql, int qr) {
if (ql > qr)
return;
int n = 0, m = 0;
for (register int i = l; i <= r; ++i)
for (register int j = 0; j < cnt[i].size(); ++j) {
n = max(n, r - i + 1 + j), a[r - i + 1 + j] += cnt[i][j];
}
for (register int i = ql; i <= qr; ++i)
for (register int j = 0; j < cnt[i].size(); ++j) {
m = max(m, i - ql + 1 + j), b[i - ql + 1 + j] += cnt[i][j];
}
int len = 1;
while (len <= n + m) len <<= 1;
mul(len);
for (register int i=0;i<=n+m;++i) ans[i+ql-r-1]=(ans[i+ql-r-1]+a[i])%mod;
for (register int i = 0; i < len; ++i) a[i] = b[i] = 0;
return;
}
inline void calc3(int l, int r, int ql, int qr) {
if (ql > qr)
return;
int n = 0, m = 0;
for (register int i = l; i <= r; ++i)
for (register int j = 0; j < cnt[i].size(); ++j) {
n = max(n, i - l + 1 + j), a[i - l + 1 + j] += cnt[i][j];
}
for (register int i = ql; i <= qr; ++i)
for (register int j = 0; j < cnt[i].size(); ++j) {
m = max(m, qr - i + 1 + j), b[qr - i + 1 + j] += cnt[i][j];
}
int len = 1;
while (len <= n + m) len <<= 1;
mul(len);
for (register int i = 0; i <= n + m; ++i)
ans[i+l+cy[500003]-1-qr]=(ans[i+l+cy[500003]-1-qr]+a[i])%mod;
for (register int i = 0; i < len; ++i) a[i] = b[i] = 0;
return;
}
void ringdp(int l, int r) {
if (l == r)
return;
int mid = l + r >> 1;
calc2(mid + 1, r, min(cy[500003] - 1, max(r, l + cy[500003] / 2)) + 1,
min(cy[500003] - 1, mid + cy[500003] / 2 + 1));
calc2(l, mid, mid + 1, min(r, l + cy[500003] / 2));
calc3(l,mid,mid+cy[500003]/2+1,min(cy[500003]-1,r+cy[500003]/2));
ringdp(l, mid), ringdp(mid + 1, r);
return;
}
void sol(void) {
for (register int i = 1; i <= n; ++i) vis[i] = 0;
for (register int i = 0; i < cy[500003]; ++i) {
vis[cy[(i+cy[500003]-1)%cy[500003]]]=vis[cy[(i+1)%cy[500003]]]=1;
vis[cy[i]] = 0, grt(cy[i]), cnt[i].resize(_t);
for (register int j = 0; j < _t; ++j) cnt[i][j] = 0;
for (register int j = 1; j <= _t; ++j) ++cnt[i][dep[qu[j]]];
}
ringdp(0, cy[500003] - 1);
return;
}
int main() {
int k, m, u, v;
scanf("%d%d", &n, &k), m = n;
for (register int i = 1; i <= n; ++i) ve[i].clear();
for (register int i = 1; i <= n; ++i) {
scanf("%d%d", &u, &v);
if (u ^ v)
ve[u].push_back(v), ve[v].push_back(u);
else
--m;
}
if (m == n) {
bfs();
for (register int i = 1; i <= n; ++i) ans[i] = 0;
gdp(), sol();
} else
treedp(1);
for (register int i=1;i<=n;++i) ans[0]=(ans[0]+qpow(i,k)*ans[i+1]%mod)%mod;
printf("%lld", (ans[0] * qpow(1LL * n * (n - 1) >> 1) % mod + mod) % mod);
return 0;
}