codeforces981H. K Paths
分析
题目大意:树上选 k k k条路径,要求选择之后某条边只能被经过 0 , 1 , k 0,1,k 0,1,k次,且不能没有经过 k k k次的边。求方案数。
所有被经过
k
k
k次的边形成的一定是一条树上的路径,考虑枚举路径的两个点
u
,
v
u,v
u,v。考虑
u
u
u子树的端点选取。要么放在
u
u
u上,要么从
u
u
u的儿子的子树挑一个点。注意一个子树只能挑一个点。
那么每个子树可以挑或者不挑,生成函数之后就是
c
(
x
)
=
∏
(
1
+
s
z
s
o
n
u
x
)
c(x)=\prod (1+sz_{son_u}x)
c(x)=∏(1+szsonux)
假设挑了
i
i
i个点在子树里面,那么自然有
A
k
i
A_k^i
Aki的方案书把这些节点分配给
k
k
k个端点。
说白了就是
A
n
s
u
=
∑
A
k
i
c
(
i
)
Ans_u=\sum A_k^ic(i)
Ansu=∑Akic(i)
两边都可以这么选,所以总的方案相当于是把两边的
A
n
s
Ans
Ans乘起来。
预处理出
A
n
s
u
Ans_u
Ansu,如果两个点不是祖先关系,直接乘起来,相当于是
1
2
[
(
∑
A
n
s
u
)
2
−
∑
A
n
s
u
2
]
\frac{1}{2}[(\sum Ans_u)^2-\sum Ans_u^2]
21[(∑Ansu)2−∑Ansu2]
这个时候还要再扣掉祖先关系的贡献,也就是
∑
A
n
s
u
g
u
\sum Ans_u g_u
∑Ansugu,其中
g
u
g_u
gu表示子树的
A
n
s
Ans
Ans和。
祖孙关系怎么处理?对于某个子树方向的子节点
v
v
v,贡献是
c
(
x
)
1
+
(
n
−
s
z
[
u
]
)
x
1
+
s
z
[
v
]
x
c(x)\frac{1+(n-sz[u])x}{1+sz[v]x}
c(x)1+sz[v]x1+(n−sz[u])x
可是每个子节点这个东西的处理是
O
(
d
)
O(d)
O(d),其中
d
d
d是度数。
但是考虑把相同子树大小的子节点合并处理,每个节点的子节点的不同的
s
z
sz
sz只有
O
(
n
)
O(\sqrt n)
O(n)
复杂度就是
O
(
n
l
o
g
n
+
n
n
)
O(nlogn+n\sqrt n)
O(nlogn+nn)
代码
#include<bits/stdc++.h>
const int N = 4e5 + 10, P = 998244353;
typedef std::vector<int> VI;
int ri() {
char c = getchar(); int x = 0, f = 1; for(;c < '0' || c > '9'; c = getchar()) if(c == '-') f = -1;
for(;c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) - '0' + c; return x * f;
}
int A[N], B[N], R[N], w[N], pr[N], to[N << 1], nx[N << 1], tp;
int fac[N], ivf[N], res[N], f[N], g[N], h[N], sz[N], val[N], fa[N], n, k, L, IvL;
VI c, d; long long ans;
void add(int u, int v) {to[++tp] = v; nx[tp] = pr[u]; pr[u] = tp;}
void adds(int u, int v) {add(u, v); add(v, u);}
int fixd(int x) {return x < 0 ? x + P : x;}
int fixu(int x) {return x >= P ? x - P : x;}
void Inc(int &a, int b) {a = fixu(a + b);}
int Pow(int x, int k) {
int r = 1;
for(;k; x = 1LL * x * x % P, k >>= 1)
if(k & 1)
r = 1LL * r * x % P;
return r;
}
int Iv(int x) {return Pow(x, P - 2);}
void Pre(int m) {
L = 1; int x = 0;
for(;(L <<= 1) < m;) ++x;
for(int i = 1;i < L; ++i)
R[i] = R[i >> 1] >> 1 | (i & 1) << x;
int wn = Pow(3, (P - 1) / L); w[0] = 1;
for(int i = 1;i < L; ++i)
w[i] = 1LL * w[i - 1] * wn % P;
IvL = Iv(L);
}
void NTT(int *F) {
for(int i = 1;i < L; ++i)
if(i < R[i])
std::swap(F[i], F[R[i]]);
for(int i = 1, d = L >> 1; i < L; i <<= 1, d >>= 1)
for(int j = 0;j < L; j += i << 1) {
int *l = F + j, *r = F + i + j, *p = w, tp;
for(int k = i; k--; ++l, ++r, p += d)
tp = 1LL * *r * *p % P, *r = fixd(*l - tp), *l = fixu(*l + tp);
}
}
void Get(VI a, int *A) {
for(int i = 0;i < a.size(); ++i)
A[i] = a[i];
}
VI operator * (VI a, VI b) {
VI c; int n = a.size() + b.size() - 1;
Pre(n);
Get(a, A); Get(b, B);
for(int i = a.size(); i < L; ++i)
A[i] = 0;
for(int i = b.size(); i < L; ++i)
B[i] = 0;
NTT(A); NTT(B);
for(int i = 0;i < L; ++i)
A[i] = 1LL * A[i] * B[i] % P;
NTT(A);
for(int i = 0;i < n; ++i)
c.push_back(1LL * A[L - i & L - 1] * IvL % P);
return c;
}
void operator *= (VI &a, int v) {
a.push_back(0);
for(int i = a.size() - 1;i; --i)
Inc(a[i], 1LL * a[i - 1] * v % P);
}
void operator /= (VI &a, int v) {
static int tp[N]; int n = a.size(), iv = Iv(v); tp[n - 1] = 0;
for(int i = n - 1; i; --i)
tp[i - 1] = 1LL * fixd(a[i] - tp[i]) * iv % P;
a.pop_back();
for(int i = 0;i < n - 1; ++i)
a[i] = tp[i];
}
VI Solve(int L, int R) {
if(L == R) return (VI){1, val[L]}; int m = L + R >> 1;
return Solve(L, m) * Solve(m + 1, R);
}
int PA(int m, int n) {return 1LL * fac[m] * ivf[m - n] % P;}
int Calc(VI a) {
int r = std::min((int)a.size(), k + 1); long long res = 0;
for(int i = 0;i < r; ++i)
res += 1LL * PA(k, i) * a[i] % P;
return res % P;
}
void Dfs(int u, int fa) {
sz[u] = 1; ::fa[u] = fa;
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa) {
Dfs(to[i], u);
Inc(g[u], g[to[i]]);
sz[u] += sz[to[i]];
}
int tp = 0;
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa)
val[++tp] = sz[to[i]];
if(!tp)
f[u] = 1, Inc(g[u], 1);
else {
std::sort(val + 1, val + tp + 1);
c = Solve(1, tp);
f[u] = Calc(c); Inc(g[u], f[u]);
tp = std::unique(val + 1, val + tp + 1) - val - 1;
for(int i = 1;i <= tp; ++i) {
d = c; d /= val[i]; d *= n - sz[u];
res[val[i]] = Calc(d);
}
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa)
h[to[i]] = res[sz[to[i]]];
}
}
void pre(int n) {
fac[0] = 1;
for(int i = 1;i <= n; ++i)
fac[i] = 1LL * fac[i - 1] * i % P;
ivf[n] = Iv(fac[n]);
for(int i = n; i; --i)
ivf[i - 1] = 1LL * ivf[i] * i % P;
}
int main() {
n = ri(); k = ri();
if(k == 1) return printf("%lld\n", (1LL * n * (n - 1) >> 1) % P), 0;
pre(std::max(n, k));
for(int i = 1;i < n; ++i)
adds(ri(), ri());
Dfs(1, 0);
for(int i = 1;i <= n; ++i)
ans += f[i];
ans %= P; ans = ans * ans;
for(int i = 1;i <= n; ++i)
ans -= 1LL * f[i] * f[i] % P;
ans %= P; ans = ans * (P + 1 >> 1);
for(int u = 1;u <= n; ++u)
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa[u])
ans += 1LL * (h[to[i]] - f[u]) * g[to[i]] % P;
printf("%d\n", fixd(ans % P));
return 0;
}