首先这道题不是二叉树也可以做
题目要求一个点的权值必须比所有儿子都大,于是可以转换为对每一个点分配一个 [ 0 , x ] [0,x] [0,x] 的权值,最终权值是子树的和
这个的方案数只与集合大小有关
为
(
x
+
S
−
1
S
−
1
)
\binom{x+S-1}{S-1}
(S−1x+S−1)
于是问题转换为统计树上大小为
k
k
k 的连通块个数
已经可以
n
2
n^2
n2
d
p
dp
dp 了
考虑到一个
d
p
dp
dp 的最大维度是
s
i
z
e
[
u
]
size[u]
size[u],可以
d
s
u
o
n
t
r
e
e
dsu\ on\ tree
dsu on tree
所有轻儿子的
s
i
z
e
size
size 和是
n
l
o
g
(
n
)
nlog(n)
nlog(n) 的
于是在每层结点可以暴力合并轻儿子,合并用分治
F
F
T
FFT
FFT,这一部分的复杂度是
O
(
n
l
o
g
(
n
)
3
)
O(nlog(n)^3)
O(nlog(n)3)
考虑重链如何合并,令
f
u
f_u
fu 为最终的答案,
g
u
g_u
gu 为只考虑轻儿子的答案
(这里的
f
,
g
f,g
f,g 是写成生成函数过后的)
f
u
=
g
u
(
f
u
+
1
+
1
)
=
g
u
(
(
g
u
+
1
(
f
u
+
2
+
1
)
+
1
)
f_u=g_u(f_{u+1}+1)=g_u((g_{u+1}(f_{u+2}+1)+1)
fu=gu(fu+1+1)=gu((gu+1(fu+2+1)+1)
考虑到当没有重儿子的时候
f
=
g
f=g
f=g
于是一个点
u
u
u 的
f
f
f 可以看做
g
u
+
g
u
∗
g
u
+
1
+
g
u
∗
g
u
+
1
∗
g
u
+
2
+
.
.
.
g_u+g_u*g_{u+1}+g_u*g_{u+1}*g_{u+2}+...
gu+gu∗gu+1+gu∗gu+1∗gu+2+... 可以分治
F
F
T
FFT
FFT
这一部分的复杂度任然是
n
l
o
g
(
n
)
3
nlog(n)^3
nlog(n)3,最后用组合数算一下就可以了
#include<bits/stdc++.h>
#define cs const
using namespace std;
typedef long long ll;
cs int N = 1e5 + 5;
cs int Mod = 998244353;
int add(int a, int b){ return a + b >= Mod ? a + b - Mod : a + b; }
int mul(int a, int b){ return 1ll * a * b % Mod; }
int ksm(int a, int b){ int ans = 1; for(;b;b>>=1,a=mul(a,a)) if(b&1) ans = mul(ans, a); return ans; }
void Mul(int &a, int b){ a = mul(a, b); }
void Add(int &a, int b){ a = add(a, b); }
int dec(int a, int b){ return a - b < 0 ? a - b + Mod : a - b; }
#define pb push_back
#define poly vector<int>
poly a[N]; int inv[N];
void Cout(poly tt){
for(int i = 0; i <tt.size(); i++) cout <<tt[i] <<" ";
cout <<'\n';
}
namespace Poly{
int bit, up, rev[N << 2];
void init(int deg){
up = 1; bit = 0; while(up <= deg) up <<= 1, ++bit;
for(int i = 0; i < up; i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<bit-1);
}
cs int C = 19;
poly w[C+1];
void prework(){
for(int i = 1; i <= C; i++) w[i].resize(1<<i-1);
int wn = ksm(3, (Mod-1)/(1<<C)); w[C][0] = 1;
for(int i = 1; i < (1<<C-1); i++) w[C][i] = mul(w[C][i-1], wn);
for(int i = C-1; i; i--) for(int j = 0; j < (1<<i-1); j++) w[i][j] = w[i+1][j<<1];
}
void NTT(poly &a, int typ){
for(int i = 0; i < up; i++) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int l = 1, i = 1; i < up; i <<= 1, ++l)
for(int j = 0; j < up; j += (i<<1))
for(int k = 0; k < i; k++){
int x = a[k+j], y = mul(w[l][k], a[k+j+i]);
a[k+j] = add(x, y); a[k+j+i] = dec(x, y);
}
if(typ == -1){
reverse(a.begin() + 1, a.end());
for(int i = 0, inv = ksm(up, Mod-2); i < up; i++) a[i] = mul(a[i], inv);
}
}
poly operator * (poly a, poly b){
int deg = a.size() + b.size() - 1;
init(deg); a.resize(up); b.resize(up);
NTT(a, 1); NTT(b, 1);
for(int i = 0; i < up; i++) Mul(a[i], b[i]);
NTT(a, -1); a.resize(deg); return a;
}
poly operator + (poly a, poly b){
int deg = max(a.size(), b.size()); a.resize(deg); b.resize(deg);
for(int i = 0; i < deg; i++) a[i] = add(a[i], b[i]); return a;
}
void Solve(int l, int r, poly &F, poly &G){
if(l == r){ F = G = a[l]; return; } int mid = (l+r) >> 1;
poly Fl, Fr, Gl, Gr;
Solve(l, mid, Fl, Gl); Solve(mid+1, r, Fr, Gr);
F = Fl * Fr; G = Gl + Fl * Gr;
}
}
int n; ll K;
vector<int> v[N];
int sz[N], son[N], fa[N];
void pre_dfs(int u, int f){
sz[u] = 1;
for(int i = 0; i < v[u].size(); i++){
int t = v[u][i]; if(t == f) continue;
fa[t] = u; pre_dfs(t, u); sz[u] += sz[t];
if(sz[t] > sz[son[u]]) son[u] = t;
}
}
poly f[N];
poly calc(int u){
for(int t = u; t; t = son[t]){
for(int i = 0; i < v[t].size(); i++) if((v[t][i] ^ fa[t]) && (v[t][i] ^ son[t])) f[t] = calc(v[t][i]);
if(f[t].empty()) f[t].pb(0); ++f[t][0]; f[t].insert(f[t].begin(), 0);
}
int k = 0;
for(int t = u; t; t = son[t]) swap(a[++k], f[t]);
poly F, G; Poly::Solve(1, k, F, G); return G;
}
int main(){
scanf("%d%lld", &n, &K); Poly::prework();
inv[0] = inv[1] = 1;
for(int i = 2; i <= n; i++) inv[i] = mul(Mod-Mod/i, inv[Mod%i]);
for(int i = 1; i < n; i++){
int x, y; scanf("%d%d", &x, &y); v[x].pb(y); v[y].pb(x);
} pre_dfs(1, 0); poly g = calc(1);
int Binom = 1, ans = 0;
for(int i = 1; i < g.size(); i++){
Add(ans, mul(Binom, g[i])); Mul(Binom, mul((K+i)%Mod, inv[i]));
} cout << ans; return 0;
}