就是要求定下两个点,并从这两个点往外延伸k个互不相交的分支的方案数
首先求出从某个点出发向子树内部延伸i个分支的方案数
对于一个点的一个儿子,这个儿子的子树中每一个点都可能成为一条分支的终点,这个儿子内就有siz个终点
不同儿子的子树互不影响,所以根据乘法原理,可以构造点v的生成函数:
F
(
x
)
=
∏
(
1
+
s
i
z
s
o
n
[
v
]
)
F(x)=\prod{(1+siz_{son[v]})}
F(x)=∏(1+sizson[v])
同时一个分支的端点可能与点v重合,考虑有x个不重合,则点v的答案就是
∑
x
=
0
k
A
k
x
f
[
i
]
\sum_{x=0}^k{A_k^xf[i]}
∑x=0kAkxf[i],把两两点乘起来统计答案,
F
(
x
)
F(x)
F(x)可以用分治NTT来求
这就是两个点不为祖先-后代的情况
否则假设这两个点为u,v,u为v的祖先,s为u的儿子且为v的祖先
我们设另一个生成函数
G
(
x
)
=
F
(
x
)
∗
1
+
(
n
−
s
i
z
u
)
x
1
+
s
i
z
s
x
G(x)=F(x)*\frac{1+(n-siz_u)x}{1+siz_sx}
G(x)=F(x)∗1+sizsx1+(n−sizu)x,表示u不能向v的方向延伸分支,且可以向上延伸分支(把u的子树以外的部分看作u的另一棵子树),则点u的答案就是
∑
x
=
0
k
A
k
x
g
[
i
]
\sum_{x=0}^k{A_k^xg[i]}
∑x=0kAkxg[i],同样把两两点乘起来统计答案,v的答案用
F
(
x
)
F(x)
F(x)表示
分治NTT是 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n),后面的枚举 s i z siz siz是 O ( n n ) O(n\sqrt{n}) O(nn),因为一个子树的不同 s i z siz siz个数不超过 n \sqrt{n} n,总复杂度 O ( n l o g 2 n + n n ) O(nlog^2n+n\sqrt{n}) O(nlog2n+nn)
Code:
#include<bits/stdc++.h>
#define poly vector<int>
#define mod 998244353
#define pb push_back
#define ll long long
#define fi first
#define se second
using namespace std;
inline int read(){
int res=0,f=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') f=-f;ch=getchar();}
while(isdigit(ch)) {res=(res<<1)+(res<<3)+(ch^48);ch=getchar();}
return res*f;
}
inline int add(int x,int y){x+=y;if(x>=mod) x-=mod;return x;}
inline int dec(int x,int y){x-=y;if(x<0) x+=mod;return x;}
inline int mul(int x,int y){return (ll)x*y%mod;}
inline void inc(int &x,int y){x+=y;if(x>=mod) x-=mod;}
inline int ksm(int a,int b){int res=1;for(;b;b>>=1,a=mul(a,a)) if(b&1) res=mul(res,a);return res;}
const int N=1e5+5;
namespace Ntt{
int lim,t,rev[N<<2];
inline void init_rev(int len){
lim=1,t=0;
while(lim<=len) lim<<=1,++t;
for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(t-1));
}
inline void ntt(poly &a,int kd){
for(int i=0;i<lim;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
int p=(mod-1)/2,w,wn,g=kd==1?3:(mod+1)/3,a0,a1;
for(int mid=1;mid<lim;mid<<=1,p>>=1){
wn=ksm(g,p);
for(int j=0,len=mid<<1;j<lim;j+=len){
w=1;
for(int k=0;k<mid;k++,w=mul(w,wn)){
a0=a[j+k],a1=mul(a[j+k+mid],w);
a[j+k]=add(a0,a1);a[j+k+mid]=dec(a0,a1);
}
}
}
if(kd==-1) for(int i=0,inv=ksm(lim,mod-2);i<lim;i++) a[i]=mul(a[i],inv);
}
}
using namespace Ntt;
inline void Mul(poly &a,int b){
a.pb(0);
for(int i=a.size()-2;~i;i--) inc(a[i+1],mul(a[i],b));
}
inline void div(poly &a,int b){
int inv=ksm(b,mod-2);
poly t=a;
for(int i=a.size()-1;i;i--) a[i-1]=mul(t[i],inv),t[i-1]=dec(t[i-1],a[i-1]);
a.pop_back();
}
inline poly operator *(poly a,poly b){
int n=a.size()-1,m=b.size()-1;
init_rev(n+m);
a.resize(lim),b.resize(lim);
ntt(a,1),ntt(b,1);
for(int i=0;i<lim;i++) a[i]=mul(a[i],b[i]);
ntt(a,-1);
return a;
}
int tmp[N],top;
inline poly solve(int l,int r){
if(l==r) {poly c;c.pb(1),c.pb(tmp[l]);return c;}
int mid=l+r>>1;
return solve(l,mid)*solve(mid+1,r);
}
poly c[N],nw;
int vis[N<<1],head[N],nxt[N<<1],tot=0;
inline void adde(int x,int y){vis[++tot]=y;nxt[tot]=head[x];head[x]=tot;}
int fa[N],siz[N],sum[N];
int n,k,ans=0;
int fac[N],ifac[N];
inline void init(){
fac[0]=fac[1]=1;
ifac[0]=ifac[1]=1;
for(int i=2;i<=k;i++) fac[i]=mul(fac[i-1],i),ifac[i]=mul(ifac[mod-mod/i*i],mod-mod/i);
for(int i=2;i<=k;i++) ifac[i]=mul(ifac[i-1],ifac[i]);
}
void dfs1(int v){
siz[v]=1;
for(int i=head[v];i;i=nxt[i]){
int y=vis[i];
if(y==fa[v]) continue;
fa[y]=v;dfs1(y);
siz[v]+=siz[y];
inc(ans,mul(sum[v],sum[y])),inc(sum[v],sum[y]);
}
int val=0;top=0;
for(int i=head[v];i;i=nxt[i]) if(vis[i]!=fa[v]) tmp[++top]=siz[vis[i]];
if(top){
c[v]=solve(1,top);
for(int i=min(k,top);~i;i--) inc(val,mul(ifac[k-i],c[v][i]));
inc(sum[v],mul(val,fac[k]));
}
else sum[v]=1;
}
map<int,int>mp[N];
void dfs2(int v){
for(int i=head[v];i;i=nxt[i]){
int y=vis[i];
if(y==fa[v]) continue;
dfs2(y);
inc(mp[v][siz[y]],sum[y]);
}
for(map<int,int>::iterator it=mp[v].begin();it!=mp[v].end();it++){
int w=0;nw=c[v];
div(nw,it->fi),Mul(nw,n-siz[v]);
for(int i=min((int)c[v].size()-1-(v==1),k);~i;i--) inc(w,mul(ifac[k-i],nw[i]));
inc(ans,mul(mul(w,fac[k]),it->se));
}
}
int main(){
n=read(),k=read();
if(k==1) {cout<<(ll)n*(n-1)/2%mod;return 0;}
init();
for(int i=1,x,y;i<n;i++) x=read(),y=read(),adde(x,y),adde(y,x);
dfs1(1);dfs2(1);
cout<<ans;
return 0;
}