[校内模拟][树形DP][组合数学][NTT]战略游戏

Link

就是要求定下两个点,并从这两个点往外延伸k个互不相交的分支的方案数
首先求出从某个点出发向子树内部延伸i个分支的方案数
对于一个点的一个儿子,这个儿子的子树中每一个点都可能成为一条分支的终点,这个儿子内就有siz个终点
不同儿子的子树互不影响,所以根据乘法原理,可以构造点v的生成函数: F ( x ) = ∏ ( 1 + s i z s o n [ v ] ) F(x)=\prod{(1+siz_{son[v]})} F(x)=(1+sizson[v])
同时一个分支的端点可能与点v重合,考虑有x个不重合,则点v的答案就是 ∑ x = 0 k A k x f [ i ] \sum_{x=0}^k{A_k^xf[i]} x=0kAkxf[i],把两两点乘起来统计答案, F ( x ) F(x) F(x)可以用分治NTT来求
这就是两个点不为祖先-后代的情况
否则假设这两个点为u,v,u为v的祖先,s为u的儿子且为v的祖先
我们设另一个生成函数 G ( x ) = F ( x ) ∗ 1 + ( n − s i z u ) x 1 + s i z s x G(x)=F(x)*\frac{1+(n-siz_u)x}{1+siz_sx} G(x)=F(x)1+sizsx1+(nsizu)x,表示u不能向v的方向延伸分支,且可以向上延伸分支(把u的子树以外的部分看作u的另一棵子树),则点u的答案就是 ∑ x = 0 k A k x g [ i ] \sum_{x=0}^k{A_k^xg[i]} x=0kAkxg[i],同样把两两点乘起来统计答案,v的答案用 F ( x ) F(x) F(x)表示

分治NTT是 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n),后面的枚举 s i z siz siz O ( n n ) O(n\sqrt{n}) O(nn ),因为一个子树的不同 s i z siz siz个数不超过 n \sqrt{n} n ,总复杂度 O ( n l o g 2 n + n n ) O(nlog^2n+n\sqrt{n}) O(nlog2n+nn )

Code:

#include<bits/stdc++.h>
#define poly vector<int>
#define mod 998244353
#define pb push_back
#define ll long long
#define fi first
#define se second
using namespace std;
inline int read(){
    int res=0,f=1;char ch=getchar();
    while(!isdigit(ch)) {if(ch=='-') f=-f;ch=getchar();}
    while(isdigit(ch)) {res=(res<<1)+(res<<3)+(ch^48);ch=getchar();}
    return res*f;
}
inline int add(int x,int y){x+=y;if(x>=mod) x-=mod;return x;}
inline int dec(int x,int y){x-=y;if(x<0) x+=mod;return x;}
inline int mul(int x,int y){return (ll)x*y%mod;}
inline void inc(int &x,int y){x+=y;if(x>=mod) x-=mod;}
inline int ksm(int a,int b){int res=1;for(;b;b>>=1,a=mul(a,a)) if(b&1) res=mul(res,a);return res;}
const int N=1e5+5;
namespace Ntt{
    int lim,t,rev[N<<2];
    inline void init_rev(int len){
     lim=1,t=0;
     while(lim<=len) lim<<=1,++t;
     for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(t-1));
 }
 inline void ntt(poly &a,int kd){
     for(int i=0;i<lim;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
     int p=(mod-1)/2,w,wn,g=kd==1?3:(mod+1)/3,a0,a1;
     for(int mid=1;mid<lim;mid<<=1,p>>=1){
         wn=ksm(g,p);
         for(int j=0,len=mid<<1;j<lim;j+=len){
             w=1;
             for(int k=0;k<mid;k++,w=mul(w,wn)){
                 a0=a[j+k],a1=mul(a[j+k+mid],w);
                 a[j+k]=add(a0,a1);a[j+k+mid]=dec(a0,a1);
             }
         }
     }
     if(kd==-1) for(int i=0,inv=ksm(lim,mod-2);i<lim;i++) a[i]=mul(a[i],inv);
    }
}
using namespace Ntt;
inline void Mul(poly &a,int b){
    a.pb(0);
    for(int i=a.size()-2;~i;i--) inc(a[i+1],mul(a[i],b));
}
inline void div(poly &a,int b){
    int inv=ksm(b,mod-2);
    poly t=a;
    for(int i=a.size()-1;i;i--) a[i-1]=mul(t[i],inv),t[i-1]=dec(t[i-1],a[i-1]);
    a.pop_back(); 
}
inline poly operator *(poly a,poly b){
    int n=a.size()-1,m=b.size()-1;
    init_rev(n+m);
    a.resize(lim),b.resize(lim);
    ntt(a,1),ntt(b,1);
    for(int i=0;i<lim;i++) a[i]=mul(a[i],b[i]);
    ntt(a,-1);
    return a;
}
int tmp[N],top;
inline poly solve(int l,int r){
    if(l==r) {poly c;c.pb(1),c.pb(tmp[l]);return c;}
    int mid=l+r>>1;
    return solve(l,mid)*solve(mid+1,r);
}
poly c[N],nw;
int vis[N<<1],head[N],nxt[N<<1],tot=0;
inline void adde(int x,int y){vis[++tot]=y;nxt[tot]=head[x];head[x]=tot;}
int fa[N],siz[N],sum[N];
int n,k,ans=0;
int fac[N],ifac[N];
inline void init(){
    fac[0]=fac[1]=1;
    ifac[0]=ifac[1]=1;
    for(int i=2;i<=k;i++) fac[i]=mul(fac[i-1],i),ifac[i]=mul(ifac[mod-mod/i*i],mod-mod/i);
    for(int i=2;i<=k;i++) ifac[i]=mul(ifac[i-1],ifac[i]); 
}
void dfs1(int v){
    siz[v]=1;
    for(int i=head[v];i;i=nxt[i]){
        int y=vis[i];
        if(y==fa[v]) continue;
        fa[y]=v;dfs1(y);
        siz[v]+=siz[y];
        inc(ans,mul(sum[v],sum[y])),inc(sum[v],sum[y]);
    }
    int val=0;top=0;
    for(int i=head[v];i;i=nxt[i]) if(vis[i]!=fa[v]) tmp[++top]=siz[vis[i]];
        if(top){
            c[v]=solve(1,top);
            for(int i=min(k,top);~i;i--) inc(val,mul(ifac[k-i],c[v][i]));
            inc(sum[v],mul(val,fac[k]));
        }
    else sum[v]=1;
}
map<int,int>mp[N];
void dfs2(int v){
    for(int i=head[v];i;i=nxt[i]){
        int y=vis[i];
        if(y==fa[v]) continue;
        dfs2(y);
        inc(mp[v][siz[y]],sum[y]);
    }
    for(map<int,int>::iterator it=mp[v].begin();it!=mp[v].end();it++){
        int w=0;nw=c[v];
        div(nw,it->fi),Mul(nw,n-siz[v]);
        for(int i=min((int)c[v].size()-1-(v==1),k);~i;i--) inc(w,mul(ifac[k-i],nw[i]));
        inc(ans,mul(mul(w,fac[k]),it->se));
    }
}
int main(){
    n=read(),k=read();
    if(k==1) {cout<<(ll)n*(n-1)/2%mod;return 0;}
    init();
    for(int i=1,x,y;i<n;i++) x=read(),y=read(),adde(x,y),adde(y,x);
    dfs1(1);dfs2(1);
    cout<<ans;
    return 0;
} 
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值