题目大意:
迷途竹林可以看成是一个n个点的森林,幽幽子定义dis(u,v)为u到v路径上的边的数量,若u和v不连通则为m。她定义整个森林的危险度为
为了去拜访永琳师匠,幽幽子需要提前知道迷途竹林的危险度。但迷途竹林的形态是时刻变化着的,所以幽幽子希望知道危险度的期望是多少。
为了避免浮点运算,答案对998244353取模。
1<=n<=2e5
题解:
这个计数写得真的是累。
很容易想到设:
f
[
i
]
f[i]
f[i]表示i个点生成树个数
g
[
i
]
g[i]
g[i]表示i个点生成森林计数
c
[
i
]
c[i]
c[i]表示i个点生成树的距离平方和。
只要能求出这三个东西,再随便卷卷就能搞出来答案了。
f
[
0
]
=
f
[
1
]
=
1
,
f
[
i
]
=
i
i
−
2
(
i
>
1
)
f[0]=f[1]=1,f[i]=i^{i-2}(i>1)
f[0]=f[1]=1,f[i]=ii−2(i>1)
g
=
e
f
g=e^f
g=ef,一些阶乘忽略掉了
copy个版还算轻松。
然后就是求c。
思路为距离平方和拆为中间的有序点对数*2-点的个数+1。
那么考虑枚举这两个有序点对,大概是三个数组卷起来。
再枚举点,大概是两个数组卷起来。
然后就愉快的解决了这题。
Code:
#include<vector>
#include<cstdio>
#include<algorithm>
#define ll long long
#define pp printf
#define fo(i, x, y) for(int i = x, B = y; i <= B; i ++)
#define ff(i, x, y) for(int i = x, B = y; i < B; i ++)
#define fd(i, x, y) for(int i = x, B = y; i >= B; i --)
#define pb push_back
using namespace std;
typedef vector<ll> V;
const int N = (1 << 21) + 5;
const int mo = 998244353;
ll ksm(ll x, ll y) {
ll s = 1;
for(; y; y /= 2, x = x * x % mo)
if(y & 1) s = s * x % mo;
return s;
}
int r[N]; ll aa[N];
void dft(V &b, int f) {
int n = b.size();
ff(i, 0, n) aa[i] = b[i];
ff(i, 0, n) {
r[i] = r[i / 2] / 2 + (i & 1) * (n / 2);
if(i < r[i]) swap(aa[i], aa[r[i]]);
}
for(int h = 1; h < n; h *= 2) {
ll wn = ksm(ksm(3, (mo - 1) / 2 / h), f == -1 ? mo - 2 : 1);
for(int j = 0; j < n; j += 2 * h) {
ll A, W = 1, *l = aa + j, *r = aa + j + h;
ff(i, 0, h) {
A = *r * W, *r = (*l - A) % mo, *l = (*l + A) % mo;
W = W * wn % mo; l ++; r ++;
}
}
}
if(f == -1) {
ll v = ksm(n, mo - 2);
ff(i, 0, n) aa[i] = (aa[i] + mo) * v % mo;
}
ff(i, 0, n) b[i] = aa[i];
}
V operator * (V a, V b) {
int z = a.size() + b.size() - 1;
int n = 1; while(n < z) n *= 2;
a.resize(n); b.resize(n);
dft(a, 1); dft(b, 1);
ff(i, 0, n) a[i] = a[i] * b[i] % mo;
dft(a, -1); a.resize(z); return a;
}
V operator - (V a, V b) {
if(a.size() < b.size()) a.resize(b.size());
ff(i, 0, a.size()) a[i] = (a[i] - b[i] + mo) % mo;
return a;
}
V qni(V a) {
int n0 = 1; while(n0 < a.size()) n0 *= 2;
V b; b.clear(); b.pb(ksm(a[0], mo - 2));
for(int n = 2; n <= n0; n *= 2) {
b.resize(n * 2); dft(b, 1);
V c = a; c.resize(n); c.resize(n * 2); dft(c, 1);
ff(i, 0, n * 2) b[i] = (b[i] * 2 - b[i] * b[i] % mo * c[i]) % mo;
dft(b, -1); b.resize(n);
}
b.resize(a.size()); return b;
}
V qd(V a) {
a[0] = 0;
ff(i, 1, a.size()) a[i - 1] = a[i] * i % mo;
return a;
}
V jf(V a) {
a.pb(0);
fd(i, a.size(), 1) a[i] = a[i - 1] * ksm(i, mo - 2) % mo;
a[0] = 0;
return a;
}
V ln(V a) {
int sa = a.size();
V b = a; b = qni(b); a = qd(a);
a = a * b; a = jf(a); a.resize(sa);
return a;
}
V exp(V a) {
int n0 = 1; while(n0 < a.size()) n0 *= 2;
V b; b.clear(); b.pb(1);
for(int n = 2; n <= n0; n *= 2) {
V c = b; c.resize(n); c = ln(c);
V d = a; d.resize(n);
c = c - d;
c.resize(2 * n); dft(c, 1);
b.resize(2 * n); dft(b, 1);
ff(i, 0, 2 * n) b[i] = (b[i] - c[i] * b[i]) % mo;
dft(b, -1); b.resize(n);
}
b.resize(a.size());
return b;
}
int T, n, m;
ll fac[N], nf[N], f[N];
int n0;
V a, b;
ll c[N], d[N], ans1[N], ans2[N];
ll ni2;
int main() {
n = 2e5;
fac[0] = 1; fo(i, 1, n) fac[i] = fac[i - 1] * i % mo;
nf[n] = ksm(fac[n], mo - 2); fd(i, n, 1) nf[i - 1] = nf[i] * i % mo;
f[0] = f[1] = 1; fo(i, 2, n) f[i] = ksm(i, i - 2);
n0 = 1; while(n0 <= n) n0 *= 2;
a.resize(4 * n0); b.resize(4 * n0);
fo(i, 1, n) a[i] = b[i] = f[i] * i % mo * i % mo * nf[i] % mo;
b[0] = 1;
dft(a, 1); dft(b, 1);
ff(i, 0, a.size()) a[i] = a[i] * a[i] % mo * b[i] % mo;
dft(a, -1);
fo(i, 1, n) c[i] = a[i] * fac[i] * 2 % mo;
a.clear(); a.resize(2 * n0);
b.clear(); b.resize(2 * n0);
fo(i, 1, n) a[i] = f[i] * i % mo * i % mo * nf[i] % mo;
dft(a, 1);
ff(i, 0, a.size()) a[i] = a[i] * a[i] % mo;
dft(a, -1);
fo(i, 1, n) c[i] = (c[i] - a[i] * fac[i]) % mo;
a.clear(); a.resize(n + 1);
fo(i, 1, n) a[i] = f[i] * nf[i] % mo;
a = exp(a);
fo(i, 0, n) d[i] = a[i] * fac[i] % mo;
a.clear(); a.resize(n + 1);
fo(i, 1, n) a[i] = c[i] * nf[i] % mo;
b.clear(); b.resize(n + 1);
fo(i, 0, n) b[i] = d[i] * nf[i] % mo;
a = a * b;
fo(i, 1, n) ans1[i] = a[i] * fac[i] % mo;
a.clear(); a.resize(n + 1);
fo(i, 1, n) a[i] = f[i] * nf[i] % mo * i % mo;
b.clear(); b.resize(n + 1);
fo(i, 1, n) b[i] = d[i] * nf[i] % mo * i % mo;
a = a * b;
fo(i, 1, n) ans2[i] = a[i] * fac[i] % mo;
ni2 = ksm(2, mo - 2);
for(scanf("%d", &T); T; T --) {
scanf("%d %d", &n, &m);
pp("%lld\n", (ans1[n] + ans2[n] * m % mo * m) % mo * ksm(d[n], mo - 2) % mo * ni2 % mo);
}
}