【题目】
CF
给定一棵
n
n
n个节点的树,你需要按顺序选择
k
k
k条路径(可以相同,先后顺序不同方案不同),使得每一条边要么不被覆盖,要么仅被一条路径覆盖,要么被所有
k
k
k条路径覆盖。求方案数模
998244353
998244353
998244353。
【解题思路】
首先考虑暴力,我们枚举一条路径,实际上就是要在两个节点的子树中分别选择
k
k
k个点,同时每个儿子子树中只能选择一个点,但根节点本身可以选择任意次。
于是现在单独考虑一个点怎么处理出这个东西,设它为
f
f
f,则生成函数就是
∑
i
=
1
m
(
s
i
z
s
o
n
i
+
1
)
\sum_{i=1}^m(siz_{son_i}+1)
i=1∑m(sizsoni+1)
其中
s
i
z
siz
siz表示子树大小,
s
o
n
son
son表示儿子节点。设这个东西的
x
i
x^i
xi项系数为
a
i
a_i
ai,则有:
f
x
=
∑
i
=
0
m
a
i
⋅
P
k
i
f_x=\sum_{i=0}^m a_i\cdot P_{k}^i
fx=i=0∑mai⋅Pki
不考虑祖先关系,则现在的答案就是:
1
2
(
(
∑
f
x
)
2
−
∑
f
x
2
)
\frac 1 2 ((\sum f_x)^2-\sum f_x^2)
21((∑fx)2−∑fx2)
上面这部分可以用分治
FFT
\text{FFT}
FFT解决。
还要考虑有祖先关系的点对贡献,那么考虑在较浅的节点处进行计算,设其为 v v v,若选择了 v v v的一个儿子 u u u子树中的节点座位另一个端点,那么实际上 v v v对应的生成函数就要乘上 1 + ( n − s i z v ) x 1 + ( s i z u ) x \frac {1+(n-siz_v)x} {1+(siz_u)x} 1+(sizu)x1+(n−sizv)x
乘或除以一个二项式的时间都是 O ( n ) O(n) O(n)的,观察到对于子树大小相同的孩子其贡献多项式是一样的,可以一起计算,那么这个总个数是 O ( n ) O(\sqrt n) O(n)级别的。
于是最后复杂度就是 O ( n log 2 n + n n ) O(n\log ^2n +n\sqrt n) O(nlog2n+nn)了。
【参考代码】
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int N=262333,mod=998244353,G=3,inv2=(mod+1)>>1;
int read()
{
int ret=0;char c=getchar();
while(!isdigit(c)) c=getchar();
while(isdigit(c)) ret=ret*10+(c^48),c=getchar();
return ret;
}
namespace Math
{
int fac[N],ifac[N],inv[N];
int add(int x){return x>=mod?x-mod:x;}
int sub(int x){return x<0?x+mod:x;}
void Add(int &x,int y){x=add(x+y);}
void Sub(int &x,int y){x=sub(x-y);}
int mul(int x,int y){return 1ll*x*y%mod;}
int qpow(int x,int y){int res=1;for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);return res;}
int getinv(int x){return qpow(x,mod-2);}
void initmath()
{
fac[0]=1;for(int i=1;i<N;++i)fac[i]=mul(fac[i-1],i);
ifac[N-1]=getinv(fac[N-1]);for(int i=N-2;~i;--i)ifac[i]=mul(ifac[i+1],i+1);
inv[0]=inv[1]=1;for(int i=2;i<N;++i)inv[i]=mul(mod-mod/i,inv[mod%i]);
}
int P(int x,int y){return 1ll*fac[x]*ifac[x-y]%mod;}
}
using namespace Math;
namespace Poly
{
int m,L,rev[N];
void ntt(int *a,int n,int op)
{
for(int i=0;i<n;++i)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
{
int wn=qpow(G,(mod-1)/(i<<1));
if(!~op) wn=getinv(wn);
for(int j=0;j<n;j+=i<<1)
{
int w=1;
for(int k=0;k<i;++k,w=mul(w,wn))
{
int x=a[j+k],y=mul(w,a[i+j+k]);
a[j+k]=add(x+y);a[i+j+k]=sub(x-y);
}
}
}
if(!~op)for(int i=0,iv=getinv(n);i<n;++i)a[i]=mul(iv,a[i]);
}
void reget(int n)
{
for(m=1,L=0;m<n;m<<=1,++L);
for(int i=0;i<m;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
}
void polymul(int *a,int *b,int *c)
{
//for(int i=0;i<m;++i) printf("%d ",a[i]); puts("");
//for(int i=0;i<m;++i) printf("%d ",b[i]); puts("");
ntt(a,m,1);ntt(b,m,1);
for(int i=0;i<m;++i) c[i]=mul(a[i],b[i]);
ntt(c,m,-1);
//for(int i=0;i<m;++i) printf("%d ",c[i]); puts("");
}
void polymult(int *a,int *b,int *c,int dega,int degb)
{
static int A[N],B[N];
reget(dega+degb-1);copy(a,a+dega,A);copy(b,b+degb,B);
//printf("degs:%d %d %d\n",dega,degb,m);
polymul(A,B,c);
//for(int i=0;i<m;++i) printf("%d ",c[i]); puts("");
fill(c+dega+degb-1,c+m,0);fill(A,A+m,0);fill(B,B+m,0);
}
void polydec(int *a,int deg,int v)
{
static int A[N];
int coe=getinv(v),iv;
for(int i=0;i<deg;++i) A[i]=a[i],a[i]=0;
for(int i=deg-1;i;--i)
{
if(A[i])
{
a[i-1]=iv=mul(A[i],coe);
Sub(A[i],mul(iv,coe));Sub(A[i-1],iv);
}
}
fill(A,A+deg,0);
}
void polyadd(int *a,int deg,int v)
{
for(int i=deg;i;--i) Add(a[i],mul(a[i-1],v));
}
}
using namespace Poly;
namespace DreamLolita
{
int n,K,tot,ans;
int head[N],siz[N],now[N],val[N],tmp[N];
int f[N],g[N],h[N],F[20][N];
struct Tway{int v,nex;}e[N];
void add(int u,int v)
{
e[++tot]=(Tway){v,head[u]};head[u]=tot;
e[++tot]=(Tway){u,head[v]};head[v]=tot;
}
void solve(int l,int r,int d)
{
if(l==r){F[d][1]=val[l];F[d][0]=1;return;}
int mid=(l+r)>>1;
solve(l,mid,d);solve(mid+1,r,d+1);
polymult(F[d],F[d+1],F[d],mid-l+2,r-mid+1);
fill(F[d+1],F[d+1]+Poly::m,0);
}
int calc(int *a,int len)
{
int res=0,lim=min(len,K);
for(int i=0;i<=lim;++i) Add(res,mul(a[i],P(K,i)));
return res;
}
bool cmp(int x,int y){return siz[x]<siz[y];}
void dfs1(int x,int fa)
{
siz[x]=1;int son=0;
for(int i=head[x];i;i=e[i].nex)
{
int v=e[i].v;
if(v==fa) continue;
dfs1(v,x);Add(g[x],g[v]);siz[x]+=siz[v];
}
for(int i=head[x];i;i=e[i].nex) if(e[i].v!=fa) val[++son]=siz[e[i].v];;
if(!son) {f[x]=g[x]=1;return;}
solve(1,son,0);
//printf("%d:\n",x);
//for(int i=1;i<=son;++i) printf("%d!",val[i]); puts("");
//for(int i=0;i<=son;++i) printf("%d ",F[0][i]); puts("");
f[x]=calc(F[0],son);Add(g[x],f[x]);son=0;
//printf("%d:%d\n",x,f[x]);
for(int i=head[x];i;i=e[i].nex) if(e[i].v!=fa) val[++son]=e[i].v;
sort(val+1,val+son+1,cmp);
for(int i=0;i<=son;++i) tmp[i]=F[0][i];
for(int i=1;i<=son;++i)
{
if(siz[val[i]]==siz[val[i-1]]) h[val[i]]=h[val[i-1]];
else
{
for(int j=0;j<=son;++j) now[j]=tmp[j];
polydec(now,son+1,siz[val[i]]);polyadd(now,son,n-siz[x]);
h[val[i]]=calc(now,son);
}
//printf("%d %d\n",val[i],h[val[i]]);
}
fill(F[0],F[0]+Poly::m,0);
for(int i=0;i<=son;++i) now[i]=tmp[i]=0;
}
void dfs2(int x,int fa)
{
for(int i=head[x];i;i=e[i].nex)
{
int v=e[i].v;
if(v==fa) continue;
Add(ans,mul(sub(h[v]-f[x]),g[v]));dfs2(v,x);
}
}
void solution()
{
initmath();n=read();K=read();
if(K==1){printf("%lld\n",1ll*n*(n-1)/2%mod);return;}
for(int i=1;i<n;++i) add(read(),read());
dfs1(1,0);dfs2(1,0);
//for(int i=1;i<=n;++i) printf("%d %d %d\n",f[i],g[i],h[i]);
//printf("%d\n",ans);
int sum=0;
for(int i=1;i<=n;++i) Add(sum,f[i]);
sum=mul(sum,sum);
for(int i=1;i<=n;++i) Sub(sum,mul(f[i],f[i]));
sum=mul(sum,inv2);Add(ans,sum);
printf("%d\n",ans);
}
}
int main()
{
#ifdef Durant_Lee
freopen("CF981H.in","r",stdin);
freopen("CF981H.out","w",stdout);
#endif
DreamLolita::solution();
return 0;
}