LOJ#2983. 「WC2019」数树 排列组合,生成函数,多项式,FFT

原文链接www.cnblogs.com/zhouzhendong/p/LOJ2983.html

前言

我怎么什么都不会?贺忙指导博客才会做。

题解

我们分三个子问题考虑。

子问题0

将红蓝共有的边连接,每一个连通块的颜色相同,不同连通块独立。

答案是 \(y ^ {连通块数}\)

子问题1

对于红树的一种连接方案,假设将在蓝树上也有的边连接起来,假设连了 \(i\) 条边,那么对答案的贡献就是:

\[y ^ n / y ^ i \]

\[z = \frac 1 y \]

根据二项式定理

\[z ^ a = \sum_{i=0}^a \binom{a}{i} (z-1)^i\]

于是得到贡献是

\[\sum_{j=0}^{n-i} \binom{n-i}{j} (z -1) ^ j\]

组合意义就是枚举所有边的子集算答案。

所以答案是

\[y ^ n \sum_{i = 0} ^ {n-1} (z-1) ^ j \sum n ^ {n-i-2} \prod_k a_k\]

其中 \(a_k\) 表示第 \(k\) 个连通块的大小。

考虑进一步展开组合意义:

\(\prod _k a_k\) 的含义就是每一个连通块取一个点的方案数,所以对蓝树进行树形DP,用 dp[x][0] 表示当前连通块没有选点的方案数,dp[x][1] 表示当前连通块已经选了一个的方案数。大力转移即可。

时间复杂度 \(O(n)\)

子问题2

考虑写出答案的式子

\[ans = y ^ n \sum_{i=1}^ n (z- 1 ) ^ {n-i} \frac{n!}{i!\prod a_j!}\left (\prod a_j^{a_j} \right)(n ^ {i-2}) ^ 2 \\ = y ^ n n ^ {-4} (z-1) ^ n \sum_{i=1}^ n \frac{n!}{i!\prod a_j!}\prod (a_j^{a_j} (z-1)^ {-1}n ^ 2) \]

注意到

\[\sum_{i=1}^ n \frac{n!}{i!\prod a_j!}\prod (a_j^{a_j} (z-1)^ {-1}n ^ 2) = [n] exp(\sum_{i\geq 1 } a_j^{a_j} (z-1)^ {-1}n ^ 2\frac{x^i}{i!})\]

于是运用多项式 exp 即可在 \(O(n\log n)\) 的时间复杂度内解决这个问题。

代码

#include <bits/stdc++.h>
#define clr(x) memset(x,0,sizeof x)
#define For(i,a,b) for (int i=(a);i<=(b);i++)
#define Fod(i,b,a) for (int i=(b);i>=(a);i--)
#define fi first
#define se second
#define pb(x) push_back(x)
#define mp(x,y) make_pair(x,y)
#define outval(x) cerr<<#x" = "<<x<<endl
#define outtag(x) cerr<<"---------------"#x"---------------"<<endl
#define outarr(a,L,R) cerr<<#a"["<<L<<".."<<R<<"] = ";\
                        For(_x,L,R)cerr<<a[_x]<<" ";cerr<<endl;
using namespace std;
typedef long long LL;
LL read(){
    LL x=0,f=0;
    char ch=getchar();
    while (!isdigit(ch))
        f|=ch=='-',ch=getchar();
    while (isdigit(ch))
        x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return f?-x:x;
}
const int mod=998244353;
int Pow(int x,int y){
    if (y<0)
        x=Pow(x,mod-2),y=-y;
    int ans=1;
    for (;y;y>>=1,x=(LL)x*x%mod)
        if (y&1)
            ans=(LL)ans*x%mod;
    return ans;
}
void Add(int &x,int y){
    if ((x+=y)>=mod)
        x-=mod;
}
void Del(int &x,int y){
    if ((x-=y)<0)
        x+=mod;
}
int Add(int x){
    return x>=mod?x-mod:x;
}
int Del(int x){
    return x<0?x+mod:x;
}
const int N=(1<<19)+1;
int Fac[N],Inv[N],Iv[N];
void getFI(){
    int n=N-1;
    for (int i=Fac[0]=1;i<=n;i++)
        Fac[i]=(LL)Fac[i-1]*i%mod;
    Inv[n]=Pow(Fac[n],mod-2);
    Fod(i,n,1)
        Inv[i-1]=(LL)Inv[i]*i%mod;
    For(i,1,n)
        Iv[i]=(LL)Inv[i]*Fac[i-1]%mod;
}
namespace fft{
     int w[N],R[N];
     void init(int n){
        int d=0;
        while ((1<<d)<n)
            d++;
        For(i,0,n-1)
            R[i]=(R[i>>1]>>1)|((i&1)<<(d-1));
        w[0]=1,w[1]=Pow(3,(mod-1)/n);
        For(i,2,n-1)
            w[i]=(LL)w[i-1]*w[1]%mod;
     }
     void FFT(int *a,int n,int flag){
        if (flag<0)
            reverse(w+1,w+n);
        For(i,0,n-1)
            if (i<R[i])
                swap(a[i],a[R[i]]);
        for (int t=n>>1,d=1;d<n;d<<=1,t>>=1)
            for (int i=0;i<n;i+=d<<1)
                for (int j=0;j<d;j++){
                    int tmp=(LL)w[t*j]*a[i+j+d]%mod;
                    a[i+j+d]=Del(a[i+j]-tmp);
                    Add(a[i+j],tmp);
                }
        if (flag<0){
            reverse(w+1,w+n);
            int inv=Pow(n,mod-2);
            For(i,0,n-1)
                a[i]=(LL)a[i]*inv%mod;
        }
     }
}
using fft::FFT;
typedef vector <int> vi;
vi Fix(vi a,int n){
    while (a.size()>n)
        a.pop_back();
    while (a.size()<n)
        a.pb(0);
    return a;
}
vi operator * (vi a,vi b){
    int s=(int)a.size()+b.size()-1,n=1;
    while (n<s)
        n<<=1;
    a=Fix(a,n),b=Fix(b,n);
    fft::init(n);
    FFT(&a[0],n,1),FFT(&b[0],n,1);
    For(i,0,n-1)
        a[i]=(LL)a[i]*b[i]%mod;
    FFT(&a[0],n,-1);
    return Fix(a,s);
}
vi operator + (vi a,vi b){
    int s=max(a.size(),b.size());
    a=Fix(a,s),b=Fix(b,s);
    For(i,0,s-1)
        Add(a[i],b[i]);
    return a;
}
vi operator - (vi a,vi b){
    int s=max(a.size(),b.size());
    a=Fix(a,s),b=Fix(b,s);
    For(i,0,s-1)
        Del(a[i],b[i]);
    return a;
}
vi pInv(vi a){
    if (a.size()==1)
        return (vi){Pow(a[0],mod-2)};
    int n=a.size();
    vi b=pInv(Fix(a,(n+1)>>1));
    return Fix(b+b-b*b*a,n);
}
vi Der(vi a){
    int n=a.size();
    For(i,0,n-2)
        a[i]=(LL)a[i+1]*(i+1)%mod;
    return Fix(a,n-1);
}
vi Int(vi a){
    int n=a.size();
    a.pb(0);
    Fod(i,n,1)
        a[i]=(LL)a[i-1]*Iv[i]%mod;
    a[0]=0;
    return a;
}
vi Ln(vi a){
    return Int(Fix(Der(a)*pInv(a),a.size()-1));
}
vi Exp(vi a){
    if (a.size()==1)
        return (vi){1};
    int n=a.size();
    vi b=Fix(Exp(Fix(a,(n+1)>>1)),n);
    return Fix(b*((vi){1}-Ln(b)+a),n);
}
int n,z,op;
namespace so0{
    map <pair <int,int>,int> Map;
    int main(){
        Map.clear();
        For(i,1,n-1){
            int x=read(),y=read();
            if (x>y)
                swap(x,y);
            Map[mp(x,y)]=1;
        }
        int c=n;
        For(i,1,n-1){
            int x=read(),y=read();
            if (x>y)
                swap(x,y);
            c-=Map[mp(x,y)];
        }
        cout<<Pow(z,c)<<endl;
        return 0;
    }
}
namespace so1{
    int inv_n,izn;
    vector <int> e[N];
    int size[N];
    int dp[N][2];
    void dfs(int x,int pre){
        dp[x][0]=dp[x][1]=1;
        for (auto y : e[x])
            if (y!=pre){
                dfs(y,x);
                int t0=dp[x][0],t1=dp[x][1];
                dp[x][0]=(LL)t0*dp[y][1]%mod;
                dp[x][1]=(LL)t1*dp[y][1]%mod;
                Add(dp[x][0],(LL)t0*dp[y][0]%mod*izn%mod);
                Add(dp[x][1],(LL)t0*dp[y][1]%mod*izn%mod);
                Add(dp[x][1],(LL)t1*dp[y][0]%mod*izn%mod);
            }
    }
    int main(){
        if (z==1){
            cout<<Pow(n,n-2)<<endl;
            return 0;
        }
        inv_n=Pow(n,mod-2);
        izn=Del(Pow(z,mod-2)-1);
        izn=(LL)izn*inv_n%mod;
        For(i,1,n-1){
            int x=read(),y=read();
            e[x].pb(y),e[y].pb(x);
        }
        dfs(1,0);
        int ans=(LL)dp[1][1]*Pow(n,n-2)%mod*Pow(z,n)%mod;
        cout<<ans<<endl;
        return 0;
    }
}
namespace so2{
    int main(){
        if (z==1){
            cout<<Pow(n,(n-2)*2)<<endl;
            return 0;
        }
        getFI();
        int iz=Del(Pow(z,mod-2)-1),tmp=(LL)Pow(iz,-1)*n%mod*n%mod;
        vi a;
        a.pb(0);
        For(i,1,n)
            a.pb((LL)Pow(i,i)*tmp%mod*Inv[i]%mod);
        a=Exp(a);
        int ans=(LL)a[n]*Fac[n]%mod;
        ans=(LL)ans*Pow(z,n)%mod*Pow(iz,n)%mod*Pow(n,-4)%mod;
        cout<<ans<<endl;
        return 0;
    }
}
int main(){
    n=read(),z=read(),op=read();
    if (op==0)
        return so0::main();
    else if (op==1)
        return so1::main();
    else if (op==2)
        return so2::main();
    return 0;
}

转载于:https://www.cnblogs.com/zhouzhendong/p/LOJ2983.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值