洛谷P4491 二项式反演,组合计数

题意:

有一个纸带,上面有 n n n个空位,每个空位可以涂 m m m种颜色任意之一,涂完颜色后,如果纸带上出现次数为 s s s的颜色有 k k k种,那么会获得 w k w_{k} wk的喜悦度,求所有涂色方案的喜悦度之和

Solution:

f i f_{i} fi为涂完色后纸带上的出现次数为 s s s的颜色数量为 i i i的涂色方案数, M M M为最多存在几种颜色出现次数为 s s s,那么

M = m i n { m , n s } , a n s = ∑ i = 0 M f i w i M=min\{m,\frac{n}{s}\},ans=\sum_{i=0}^{M}f_{i}w_{i} M=min{m,sn},ans=i=0Mfiwi

求一个序列上有多少满足某个限制的可以考虑设 g i g_{i} gi为钦定 i i i种符合条件,这样就得到了至少次数,这里设 g i g_{i} gi为钦定 i i i种颜色,他们的出现次数为 s s s的方案种数,那么

g i = C m i A n i × s ( s ! ) k ( m − i ) n − i × s g_{i}=C_{m}^{i}\frac{A_{n}^{i\times s}}{(s!)^k}(m-i)^{n-i\times s} gi=Cmi(s!)kAni×s(mi)ni×s

含义为:先选出 i i i种颜色钦定,然后选出一个长 i × s i\times s i×s的排列来放置这些颜色,由于相同颜色不再加以区分,所以每种颜色需要除他们长度的全排列,有 i i i种,即除 ( s ! ) k (s!)^k (s!)k,然后剩下的位置每个位置都可以选择 ( m − i ) (m-i) (mi)种颜色

那么 f i f_{i} fi

f i = g i − ∑ j = i + 1 M C j i f j f_{i}=g_{i}-\sum_{j=i+1}^{M}C_{j}^{i}f_{j} fi=gij=i+1MCjifj

g i g_{i} gi里面钦定了 i i i种颜色,里面实际上包含了 C j i f j C_{j}^{i}f_{j} Cjifj,而不是仅包含了 f j f_{j} fj,因为长度 j j j里面选择 i i i个钦定可以有 C j i C_{j}^{i} Cji种选法。不妨强调一下 g i g_{i} gi的定义,是钦定了 i i i个颜色出现次数为 s s s,而不是至少有 i i i个颜色出现次数为 s s s

由于当 j = i j=i j=i时, C j i f j = f i C_{j}^{i}f_{j}=f_{i} Cjifj=fi,所以移项得到

g i = ∑ j = i M C j i f j g_{i}=\sum_{j=i}^{M}C_{j}^{i}f_{j} gi=j=iMCjifj

二项式反演得到

f i = ∑ j = i M ( i − 1 ) j − i C j i g j f_{i}=\sum_{j=i}^{M}(i-1)^{j-i}C_{j}^{i}g_{j} fi=j=iM(i1)jiCjigj

展开后得到

f i = n ! m ! i ! ∑ j = i M ( − 1 ) j − i ( m − j ) n − j × s ( j − i ) ! ( m − j ) ! ( n − j × s ) ! ( s ! ) j f_{i}=\frac{n!m!}{i!}\sum_{j=i}^{M}\frac{(-1)^{j-i}(m-j)^{n-j\times s}}{(j-i)!(m-j)!(n-j\times s)!(s!)^j} fi=i!n!m!j=iM(ji)!(mj)!(nj×s)!(s!)j(1)ji(mj)nj×s

s ( i ) = ( − 1 ) i i ! , t ( i ) = ( m − i ) n − i × s ( m − i ) ! ( n − i × s ) ! ( s ! ) i s(i)=\frac{(-1)^i}{i!},t(i)=\frac{(m-i)^{n-i\times s}}{(m-i)!(n-i\times s)!(s!)^i} s(i)=i!(1)i,t(i)=(mi)!(ni×s)!(s!)i(mi)ni×s

则有

f i = n ! m ! i ! ∑ j = i M s ( j − i ) t ( j ) f_{i}=\frac{n!m!}{i!}\sum_{j=i}^{M}s(j-i)t(j) fi=i!n!m!j=iMs(ji)t(j)

右边是一个类卷积形式,可以有如下两种方法转化为标准卷积形式

法一:

卷积下标相加为定值,不妨将 s ( j − i ) s(j-i) s(ji)转换为 s ( i − j ) s(i-j) s(ij),这样就可以达成目的,于是考虑重新令

s ( i ) = ( − 1 ) − i ( − i ) ! s(i)=\frac{(-1)^{-i}}{(-i)!} s(i)=(i)!(1)i

这样原式就为

f i = n ! m ! i ! ∑ j = i M s ( i − j ) t ( j ) f_{i}=\frac{n!m!}{i!}\sum_{j=i}^{M}s(i-j)t(j) fi=i!n!m!j=iMs(ij)t(j)

但是 s ( i ) s(i) s(i)这个函数有负数下标,存不进数组,我们考虑将他平移 M M M个单位,即再重新令

s ( i ) = ( − 1 ) M − i ( M − i ) ! s(i)=\frac{(-1)^{M-i}}{(M-i)!} s(i)=(Mi)!(1)Mi

那么原式即

f i = n ! m ! i ! ∑ j = i M s ( M + i − j ) t ( j ) f_{i}=\frac{n!m!}{i!}\sum_{j=i}^{M}s(M+i-j)t(j) fi=i!n!m!j=iMs(M+ij)t(j)

相加为定值 M + i M+i M+i,可以顺利卷积,答案即卷积结果的第 M + i M+i M+i

法二:

s ( x ) s(x) s(x)反转,原式变为

f i = n ! m ! i ! ∑ j = i M s ( M − ( j − i ) ) t ( j ) f_{i}=\frac{n!m!}{i!}\sum_{j=i}^{M}s(M-(j-i))t(j) fi=i!n!m!j=iMs(M(ji))t(j)

此时答案是卷积结果的第 M + i M+i M+i

或者将 t ( x ) t(x) t(x)反转,原式变为

f i = n ! m ! i ! ∑ j = i M s ( j − i ) t ( M − j ) f_{i}=\frac{n!m!}{i!}\sum_{j=i}^{M}s(j-i)t(M-j) fi=i!n!m!j=iMs(ji)t(Mj)

此时答案即卷积结果的第 M − i M-i Mi

// #include<bits/stdc++.h>
#include<iostream>
#include<cstring>
#include<vector>
#include<algorithm>
#include<cstdio>
#include<time.h>
using namespace std;

using ll=long long;
const int N=10000005,inf=0x3fffffff;
const long long INF=0x3f3f3f3f3f3f,mod=1004535809;

ll qpow(ll a,ll b)
{
    ll ret=1,base=a;
    while(b)
    {
        if(b&1) ret=ret*base%mod;
        base=base*base%mod;
        b>>=1;
    }
    return ret;
}

ll inv(ll x){return qpow(x,mod-2);}

const long long inv2=inv(2),inv3=inv(3);

namespace poly
{
    //多项式开根二次剩余部分
    ll si;
    struct Complex_MOD
    {
        ll a,b;
        Complex_MOD operator*(const Complex_MOD &t) const
        {
            Complex_MOD res;
            res.a=(a*t.a%mod+b*t.b%mod*si%mod)%mod;
            res.b=(a*t.b%mod+b*t.a%mod)%mod;
            return res;
        }
    };
    inline ll Pow_Complex(Complex_MOD val,ll b)
    {
        Complex_MOD res={1,0};
        while(b)
        {
            if(b&1) res=res*val;
            val=val*val;
            b>>=1;
        }
        return res.a;
    }
    ll squaremod(ll n)
    {
        if(!n) return 0;
        srand((unsigned)(time(NULL)));
        ll p1,p2;
        while(1)
        {
            p1=1ll*rand()*rand()%mod;
            p2=(p1*p1%mod-n+mod)%mod;
            if(::qpow(p2,(mod-1)/2)!=1) break;
        }
        si=p2;
        Complex_MOD val={p1,1};
        int ans=Pow_Complex(val,(mod+1)/2);
        if(mod-ans<ans) ans=mod-ans;
        return ans;
    }
    //以上是二次剩余
    int pos[N],invs[N],invcnt,sqrts[N],sqrtcnt,exps[N],expcnt;
    ll a[N],b[N],invtmp[N],dertmp[N],multitmp[N],lntmp[N],lnlntmp[N],qpowtmp[N];
    int getlen(int k)
    {
        int ret=0;
        while(k){ret++;k>>=1;}
        return ret;
    }
    int getrev(int k,int len)
    {
        int ret=0;
        while(k){ret=(ret<<1|(k&1));k>>=1;len--;}
        return ret<<len;
    }
    int getpos(int n)
    {
        int limit=1;
        while(limit<=n) limit<<=1;
        int len=getlen(limit-1);
        for(int i=0;i<limit;i++) pos[i]=getrev(i,len);
        return limit;
    }
    void ntt(ll *a,int limit,int op)
    {
        for(int i=0;i<limit;i++)
            if(i<pos[i]) swap(a[i],a[pos[i]]);
        for(int len=2;len<=limit;len<<=1)
        {
            ll base=::qpow(op==1?3:inv3,(mod-1)/len);
            for(int l=0;l<limit;l+=len)
            {
                ll now=1;
                for(int i=l;i<l+len/2;i++)
                {
                    ll x=a[i]%mod,y=now*a[i+len/2]%mod;
                    a[i]=(x+y)%mod;
                    a[i+len/2]=(x-y+mod)%mod;
                    now=now*base%mod;
                }
            }
        }
    }
    void prepare(int *s,int &cnt,int n)
    {
        cnt=0;
        while(n>1) s[++cnt]=n,n=n+1>>1;
    }
    void inv(ll *f,ll *g,int n)//求出f的乘法逆元,放到g内,f的长度为n,需要保证g是空的
    {
        prepare(invs,invcnt,n); 
        g[0]=::inv(f[0]);
        for(int i=invcnt;i>=1;i--)
        {
            int limit=getpos(invs[i]<<1);
            memcpy(a,f,sizeof(ll)*invs[i]); fill(a+invs[i],a+limit,0);
            ntt(a,limit,1); ntt(g,limit,1);
            for(int i=0;i<limit;i++) g[i]=g[i]*((2-a[i]*g[i]%mod)%mod+mod)%mod;
            ntt(g,limit,-1); ll tmp=::inv(limit);
            for(int i=0;i<limit;i++) g[i]=g[i]*tmp%mod;
            fill(g+invs[i],g+limit,0);
        }
        memset(a,0,sizeof(a));
    }
    void sqrt(ll *f,ll *g,int n)//多项式f开根,存放在g内,需要保证g是空的
    {
        prepare(sqrts,sqrtcnt,n); g[0]=squaremod(f[0]);//二次剩余
        for(int i=sqrtcnt;i>=1;i--)
        {
            int limit=getpos(sqrts[i]<<1);
            memset(invtmp,0,sizeof(ll)*limit); inv(g,invtmp,sqrts[i]);
            memcpy(a,f,sizeof(ll)*sqrts[i]); fill(a+sqrts[i],a+limit,0);
            ntt(g,limit,1); ntt(invtmp,limit,1); ntt(a,limit,1);
            for(int i=0;i<limit;i++) g[i]=inv2*(g[i]+invtmp[i]*a[i]%mod)%mod;
            ntt(g,limit,-1); ll tmp=::inv(limit);
            for(int i=0;i<limit;i++) g[i]=g[i]*tmp%mod;
            fill(g+sqrts[i],g+limit,0);
        }
    }
    void derivation(ll *f,ll *g,int n)//f求导,放入g
    {
        for(int i=0;i<n-1;i++) g[i]=f[i+1]*(i+1);
        g[n-1]=0;
    }
    void integral(ll *f,ll *g,int n)//f积分,放入g
    {
        g[0]=0;
        for(int i=1;i<n;i++) g[i]=f[i-1]*::inv(i)%mod;
    }
    void multi(ll *f,ll *g,ll *t,int n,int m)//f*g放入t
    {
        int limit=getpos(n+m);
        memcpy(a,f,sizeof(ll)*n); fill(a+n,a+limit,0);
        memcpy(t,g,sizeof(ll)*m); fill(t+m,t+limit,0);
        ntt(a,limit,1); ntt(t,limit,1);
        for(int i=0;i<limit;i++) t[i]=t[i]*a[i]%mod;
        ntt(t,limit,-1); ll tmp=::inv(limit);
        for(int i=0;i<limit;i++) t[i]=t[i]*tmp%mod;
        fill(t+n+m,t+limit,0);
    }
    void ln(ll *f,ll *g,int n)
    {
        derivation(f,dertmp,n);
        int limit=1;
        while(limit<=(n<<1)) limit<<=1;
        memset(invtmp,0,sizeof(ll)*limit);
        inv(f,invtmp,n);
        multi(dertmp,invtmp,multitmp,n,n);
        integral(multitmp,g,n);
    }
    void exp(ll *f,ll *g,int n)//e^(f)放入g,需要保证g是空的
    {
        prepare(exps,expcnt,n); g[0]=1;
        for(int i=expcnt;i>=1;i--)
        {
            int limit=getpos(exps[i]<<1);
            ln(g,lntmp,exps[i]); fill(lntmp+exps[i],lntmp+limit,0);
            memcpy(a,f,sizeof(ll)*exps[i]); fill(a+exps[i],a+limit,0);
            ntt(g,limit,1); ntt(lntmp,limit,1); ntt(a,limit,1);
            for(int i=0;i<limit;i++) g[i]=g[i]*(((1-lntmp[i]+a[i])%mod+mod)%mod)%mod;
            ntt(g,limit,-1); ll tmp=::inv(limit);
            for(int i=0;i<limit;i++) g[i]=g[i]*tmp%mod;
            fill(g+exps[i],g+limit,0);
            memset(lntmp,0,sizeof(ll)*limit);
        }
    }
    void read(char *s,int n,ll &k1,ll &k2,ll &k3)
    {
        k1=k2=k3=0; int len=strlen(s+1);
        for(int i=1;i<=len;i++)
        {
            k1=(10*k1+(s[i]-'0'))%mod; 
            k2=(10*k2+(s[i]-'0'))%(mod-1);
            if(k3<n) k3=10*k3+(s[i]-'0'); 
        }
    }
    void qpow(ll *f,ll *g,char *s,int n)//字符串形式给出幂,f^(k)放入g中,需要保证g是空的
    {
        ll k1,k2,k3; int sta=0;
        read(s,n,k1,k2,k3);
        while(sta<n&&f[sta]==0) sta++;
        if(sta==n||sta*k1>n||(f[0]==0&&k3>=n))
        {
            for(int i=0;i<n;i++) g[i]=0;
            return;
        }
        ll invsta=::inv(f[sta]),tmp=::qpow(f[sta],k2);
        for(int i=0;i<n;i++) qpowtmp[i]=f[i];
        for(int i=0;i<n;i++) qpowtmp[i]=qpowtmp[i+sta]*invsta%mod;
        ln(qpowtmp,lnlntmp,n);
        for(int i=0;i<n;i++) lnlntmp[i]=lnlntmp[i]*k1%mod;
        exp(lnlntmp,g,n);
        for(int i=n-1;i>=sta*k1;i--) g[i]=g[i-sta*k1]*tmp%mod;
        for(int i=0;i<sta*k1;i++) g[i]=0;
    }
};

int n,m,s,M;
ll w[N],fac[N],invfac[N],f[N],g[N],t[N],ans;

ll F(ll i){return ((((M-i)&1?-1:1)*invfac[M-i])%mod+mod)%mod;}
ll G(ll i){return ::qpow(m-i,n-i*s)*invfac[m-i]%mod*invfac[n-i*s]%mod*::qpow(invfac[s],i)%mod;}

int main()
{
    cin>>n>>m>>s; M=min(n/s,m);
    for(int i=0;i<=m;i++) scanf("%lld",&w[i]);
    for(int i=fac[0]=invfac[0]=1;i<N;i++) fac[i]=fac[i-1]*i%mod;
    invfac[N-1]=::inv(fac[N-1]);
    for(int i=N-2;i>=1;i--) invfac[i]=invfac[i+1]*(i+1)%mod;
    for(int i=0;i<=M;i++) f[i]=F(1ll*i),g[i]=G(1ll*i);
    poly::multi(f,g,t,M+1,M+1);
    for(int i=0;i<=M;i++) ans=(ans+invfac[i]*t[i+M]%mod*w[i]%mod)%mod;
    cout<<ans*fac[n]%mod*fac[m]%mod;
    return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值