HDU--4248、A Famous Stone Collector (计数类dp)

题目链接

题面:
在这里插入图片描述
题意:
n n n 种颜色的石子,第 i i i 种颜色的石子有 a i a_i ai 个。
我现在要从里面选出若干个石子排成一排,问有多少种不同的序列。

两个序列不同当且仅当长度不同或者至少有某个位置的对应颜色不同。

题解:
考虑 d p [ i ] [ j ] dp[i][j] dp[i][j],表示我考虑的前 i i i 种颜色的石子,现在构成的长度为 j j j 的序列的不同排列的方案数。

转移方程显然 d p [ i ] [ j ] = ∑ k = 0 m i n ( s u m [ i ] , j ) d p [ i − 1 ] [ j − k ] ∗ C j k dp[i][j]=\sum\limits_{k=0}^{min(sum[i],j)}dp[i-1][j-k]*C_j^k dp[i][j]=k=0min(sum[i],j)dp[i1][jk]Cjk

虽然这样的理论时间复杂度大概是 O ( T ∗ 1 e 8 ) O(T*1e8) O(T1e8) T T T 为数据的组数。

但是能过。

#pragma GCC optimize(3)
#pragma GCC optimize("Ofast","inline","-ffast-math")
#pragma GCC target("avx,sse2,sse3,sse4,mmx")
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
//#include<unordered_map>
#include<set>
//#include<unordered_set>
namespace onlyzhao
{
    #define ui unsigned int
    #define ll long long
    #define llu unsigned ll
    #define ld long double
    #define pr make_pair
    #define pb push_back
    #define lc (cnt<<1)
    #define rc (cnt<<1|1)
    #define len(x)  (t[(x)].r-t[(x)].l+1)
    #define tmid ((l+r)>>1)
    #define fhead(x) for(int i=head[(x)];i;i=nt[i])
    #define max(x,y) ((x)>(y)?(x):(y))
    #define min(x,y) ((x)>(y)?(y):(x))
    #define one(n) for(int i=1;i<=(n);i++)
    #define rone(n) for(int i=(n);i>=1;i--)
    #define fone(i,x,n) for(int i=(x);i<=(n);i++)
    #define frone(i,n,x) for(int i=(n);i>=(x);i--)
    #define fonk(i,x,n,k) for(int i=(x);i<=(n);i+=(k))
    #define fronk(i,n,x,k) for(int i=(n);i>=(x);i-=(k))
    #define two(n,m) for(int i=1;i<=(n);i++) for(int j=1;j<=(m);j++)
    #define ftwo(i,n,j,m) for(int i=1;i<=(n);i++) for(int j=1;j<=(m);j++)
    #define fvc(vc) for(int i=0;i<vc.size();i++)
    #define frvc(vc) for(int i=vc.size()-1;i>=0;i--)
    #define forvc(i,vc) for(int i=0;i<vc.size();i++)
    #define forrvc(i,vc) for(int i=vc.size()-1;i>=0;i--)
    #define cls(a) memset(a,0,sizeof(a))
    #define cls1(a) memset(a,-1,sizeof(a))
    #define clsmax(a) memset(a,0x3f,sizeof(a))
    #define clsmin(a) memset(a,0x80,sizeof(a))
    #define cln(a,num) memset(a,0,sizeof(a[0])*num)
    #define cln1(a,num) memset(a,-1,sizeof(a[0])*num)
    #define clnmax(a,num) memset(a,0x3f,sizeof(a[0])*num)
    #define clnmin(a,num) memset(a,0x80,sizeof(a[0])*num)
    #define sc(x) scanf("%d",&x)
    #define sc2(x,y) scanf("%d%d",&x,&y)
    #define sc3(x,y,z) scanf("%d%d%d",&x,&y,&z)
    #define scl(x) scanf("%lld",&x)
    #define scl2(x,y) scanf("%lld%lld",&x,&y)
    #define scl3(x,y,z) scanf("%lld%lld%lld",&x,&y,&z)
    #define scf(x) scanf("%lf",&x)
    #define scf2(x,y) scanf("%lf%lf",&x,&y)
    #define scf3(x,y,z) scanf("%lf%lf%lf",&x,&y,&z)
    #define scs(x) scanf("%s",x+1)
    #define scs0(x) scanf("%s",x)
    #define scline(x) scanf("%[^\n]%*c",x+1)
    #define scline0(x) scanf("%[^\n]%*c",x)
    #define pcc(x) putchar(x)
    #define pc(x) printf("%d\n",x)
    #define pc2(x,y) printf("%d %d\n",x,y)
    #define pc3(x,y,z) printf("%d %d %d\n",x,y,z)
    #define pck(x) printf("%d ",x)
    #define pcl(x) printf("%lld\n",x)
    #define pcl2(x,y) printf("%lld %lld\n",x,y)
    #define pcl3(x,y,z) printf("%lld %lld %d\n",x,y,z)
    #define pclk(x) printf("%lld ",x)
    #define pcf2(x) printf("%.2f\n",x)
    #define pcf6(x) printf("%.6f\n",x)
    #define pcf8(x) printf("%.8f\n",x)
    #define pcs(x) printf("%s\n",x+1)
    #define pcs0(x) printf("%s\n",x)
    #define pcline(x) printf("%d**********\n",x)
    #define casett int tt;sc(tt);int pp=0;while(tt--)

    char buffer[100001],*S,*T;
    inline char Get_Char()
    {
        if (S==T)
        {
            T=(S=buffer)+fread(buffer,1,100001,stdin);
            if (S==T) return EOF;
        }
        return *S++;
    }
    inline ll read()
    {
        char c;ll re=0;
        for(c=Get_Char();c<'0'||c>'9';c=Get_Char());
        while(c>='0'&&c<='9') re=re*10+(c-'0'),c=Get_Char();
        return re;
    }
};
using namespace onlyzhao;
using namespace std;

const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=1e9+7;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=40100;
const int maxm=100100;
const int up=1e9;
const int maxp=1010;

int fac[maxn],inv[maxn];
int pc[maxn];

ll mypow(ll a,ll b)
{
    ll ans=1;
    while(b)
    {
        if(b&1) ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}


void init(void)
{
    fac[0]=1;
    for(int i=1;i<maxn;i++)
        fac[i]=1ll*fac[i-1]*i%mod;
    inv[maxn-1]=mypow(fac[maxn-1],mod-2);
    for(int i=maxn-2;i>=0;i--)
        inv[i]=1ll*inv[i+1]*(i+1)%mod;
}

int C(int n,int m)
{
    return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int dp[110][10010];

int main(void)
{
    int n;
    int pp=0;
    init();
    while(scanf("%d",&n)!=EOF)
    {
        int sum=0;
        for(int i=1;i<=n;i++)
            scanf("%d",&pc[i]);
        memset(dp,0,sizeof(dp));
        dp[0][0]=1;

        for(int i=1;i<=n;i++)
        {
            sum+=pc[i];
            for(int j=0;j<=sum;j++)
            {
                dp[i][j]=0;
                for(int k=0;k<=pc[i]&&k<=j;k++)
                    dp[i][j]=(dp[i][j]+1ll*dp[i-1][j-k]*C(j,k))%mod;
            }
        }
        int ans=0;
        for(int i=1;i<=sum;i++)
            ans=(ans+dp[n][i])%mod;
        printf("Case %d: %d\n",++pp,ans);
    }
    return 0;
}

其实这个题,在写的时候用的生成函数写的,当时看过的不多,就以为是个数论题。
但是一看模数是 1 e 9 + 7 1e9+7 1e9+7 ,感觉也不像是用生成函数做。但是也没往 d p dp dp 上面想。
如果是 马哥 或者 冰哥 开了这个题的话,肯定就是秒过。

其实生成函数也能过,原 O J OJ OJ 开了 15 s 15s 15s 时限,跑了大约 5 s 5s 5s

#pragma GCC optimize(2)
#pragma GCC optimize("Ofast","inline","-ffast-math")
#pragma GCC target("avx,sse2,sse3,sse4,mmx")
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
//#include<unordered_map>
#include<set>
//#include<unordered_set>
namespace onlyzhao
{
    #define ui unsigned int
    #define ll long long
    #define llu unsigned ll
    #define ld long double
    #define pr make_pair
    #define pb push_back
    #define lc (cnt<<1)
    #define rc (cnt<<1|1)
    #define len(x)  (t[(x)].r-t[(x)].l+1)
    #define tmid ((l+r)>>1)
    #define fhead(x) for(int i=head[(x)];i;i=nt[i])
    #define max(x,y) ((x)>(y)?(x):(y))
    #define min(x,y) ((x)>(y)?(y):(x))
    #define one(n) for(int i=1;i<=(n);i++)
    #define rone(n) for(int i=(n);i>=1;i--)
    #define fone(i,x,n) for(int i=(x);i<=(n);i++)
    #define frone(i,n,x) for(int i=(n);i>=(x);i--)
    #define fonk(i,x,n,k) for(int i=(x);i<=(n);i+=(k))
    #define fronk(i,n,x,k) for(int i=(n);i>=(x);i-=(k))
    #define two(n,m) for(int i=1;i<=(n);i++) for(int j=1;j<=(m);j++)
    #define ftwo(i,n,j,m) for(int i=1;i<=(n);i++) for(int j=1;j<=(m);j++)
    #define fvc(vc) for(int i=0;i<vc.size();i++)
    #define frvc(vc) for(int i=vc.size()-1;i>=0;i--)
    #define forvc(i,vc) for(int i=0;i<vc.size();i++)
    #define forrvc(i,vc) for(int i=vc.size()-1;i>=0;i--)
    #define cls(a) memset(a,0,sizeof(a))
    #define cls1(a) memset(a,-1,sizeof(a))
    #define clsmax(a) memset(a,0x3f,sizeof(a))
    #define clsmin(a) memset(a,0x80,sizeof(a))
    #define cln(a,num) memset(a,0,sizeof(a[0])*num)
    #define cln1(a,num) memset(a,-1,sizeof(a[0])*num)
    #define clnmax(a,num) memset(a,0x3f,sizeof(a[0])*num)
    #define clnmin(a,num) memset(a,0x80,sizeof(a[0])*num)
    #define sc(x) scanf("%d",&x)
    #define sc2(x,y) scanf("%d%d",&x,&y)
    #define sc3(x,y,z) scanf("%d%d%d",&x,&y,&z)
    #define scl(x) scanf("%lld",&x)
    #define scl2(x,y) scanf("%lld%lld",&x,&y)
    #define scl3(x,y,z) scanf("%lld%lld%lld",&x,&y,&z)
    #define scf(x) scanf("%lf",&x)
    #define scf2(x,y) scanf("%lf%lf",&x,&y)
    #define scf3(x,y,z) scanf("%lf%lf%lf",&x,&y,&z)
    #define scs(x) scanf("%s",x+1)
    #define scs0(x) scanf("%s",x)
    #define scline(x) scanf("%[^\n]%*c",x+1)
    #define scline0(x) scanf("%[^\n]%*c",x)
    #define pcc(x) putchar(x)
    #define pc(x) printf("%d\n",x)
    #define pc2(x,y) printf("%d %d\n",x,y)
    #define pc3(x,y,z) printf("%d %d %d\n",x,y,z)
    #define pck(x) printf("%d ",x)
    #define pcl(x) printf("%lld\n",x)
    #define pcl2(x,y) printf("%lld %lld\n",x,y)
    #define pcl3(x,y,z) printf("%lld %lld %d\n",x,y,z)
    #define pclk(x) printf("%lld ",x)
    #define pcf2(x) printf("%.2f\n",x)
    #define pcf6(x) printf("%.6f\n",x)
    #define pcf8(x) printf("%.8f\n",x)
    #define pcs(x) printf("%s\n",x+1)
    #define pcs0(x) printf("%s\n",x)
    #define pcline(x) printf("%d**********\n",x)
    #define casett int tt;sc(tt);int pp=0;while(tt--)

    char buffer[100001],*S,*T;
    inline char Get_Char()
    {
        if (S==T)
        {
            T=(S=buffer)+fread(buffer,1,100001,stdin);
            if (S==T) return EOF;
        }
        return *S++;
    }
    inline ll read()
    {
        char c;ll re=0;
        for(c=Get_Char();c<'0'||c>'9';c=Get_Char());
        while(c>='0'&&c<='9') re=re*10+(c-'0'),c=Get_Char();
        return re;
    }
};
using namespace onlyzhao;
using namespace std;

const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=1e9+7;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=40100;
const int maxm=100100;
const int up=1e9;
const int maxp=1010;

const int p=mod;
const int g=3;
int fi[maxn];
int n,m;
int aa[maxn],bb[maxn];
int pc[maxn];
ll fac[maxn],inv[maxn];




ll mypow(ll a,ll b,ll p=mod)
{

    if(b<0) return mypow(mypow(a,p-2,p),-b,p);
    a%=p;
    ll ans=1;
    while(b)
    {
        if(b&1) ans=ans*a%p;
        a=a*a%p;
        b>>=1;
    }
    return ans%p;
}

void init(void)
{
    fac[0]=1;
    for(int i=1;i<maxn;i++)
        fac[i]=fac[i-1]*i%mod;
    inv[maxn-1]=mypow(fac[maxn-1],mod-2);
    for(int i=maxn-2;i>=0;i--)
        inv[i]=inv[i+1]*(i+1)%mod;
}

struct Complex
{
    double x,y;
    Complex(double xx=0.0,double yy=0.0)
    {
        x=xx,y=yy;
    }
    Complex operator - (const Complex &b) const
    {
        return Complex(x-b.x,y-b.y);
    }

    Complex operator + (const Complex &b) const
    {
        return Complex(x+b.x,y+b.y);
    }

    Complex operator * (const Complex &b) const
    {
        return Complex(x*b.x-y*b.y,x*b.y+y*b.x);
    }
};
int ans[maxn];
Complex a1[maxn],b1[maxn],a2[maxn],b2[maxn],ww[maxn],a[maxn];


void fft(Complex *x,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<fi[i]) swap(x[i],x[fi[i]]);

    for(int i=1;i<len;i<<=1)
    {
        for(int r=i<<1,j=0;j<len;j+=r)
        {
            for(int k=0;k<i;k++)
            {
                Complex w=ww[len/i*k];
                w.y*=f;

                Complex xx=x[j+k],yy=w*x[j+i+k];
                x[j+k]=xx+yy;
                x[j+i+k]=xx-yy;
            }
        }
    }
    if(f==-1)
        for(int i=0;i<len;i++)
            x[i].x/=len;
}

void get(Complex *x,Complex *y,int len,int pm)
{
    for(int i=0;i<len;i++) a[i]=x[i]*y[i];
    fft(a,len,-1);
    for(int i=0;i<len;i++)
        ans[i]=(ans[i]+(ll)(a[i].x+0.5)%p*1ll*pm)%p;
}

void get(int n,int m,int *a,int *b,int p=mod)
{

    int len=1,cnt=0;
    while(len<=n+m) len<<=1,cnt++;
    for(int i=0;i<len;i++)
    {
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
        ww[i]=Complex(cos(pi/len*i),sin(pi/len*i));
    }

    int pm=32768;
    int x;
    for(int i=0;i<=n;i++)
    {
        x=a[i];
        a1[i].x=x/pm,b1[i].x=x%pm;
        a1[i].y=0,b1[i].y=0;
    }
    for(int i=n+1;i<=len;i++)
        a1[i].x=a1[i].y=b1[i].x=b1[i].y=0;

    for(int i=0;i<=m;i++)
    {
        x=b[i];
        a2[i].x=x/pm,b2[i].x=x%pm;
        a2[i].y=0,b2[i].y=0;
    }
    for(int i=m+1;i<=len;i++)
        a2[i].x=a2[i].y=b2[i].x=b2[i].y=0;


    fft(a1,len,1);
    fft(b1,len,1);
    fft(a2,len,1);
    fft(b2,len,1);

    get(a1,a2,len,1ll*pm*pm%p);
    get(a1,b2,len,pm%p);
    get(a2,b1,len,pm%p);
    get(b1,b2,len,1);


    for(int i=0;i<=len;i++)
        a[i]=ans[i],ans[i]=0;

}


int main(void)
{
    init();
    int n;
    int pp=0;
    while(scanf("%d",&n)!=EOF)
    {
        int sum=0;
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&pc[i]);
        }
        for(int i=1;i<=n;i++)
        {
            if(i==1)
            {
                for(int j=0;j<=pc[i];j++)
                    aa[j]=inv[j];
                sum=pc[i];
            }
            else
            {
                for(int j=0;j<=pc[i];j++)
                    bb[j]=inv[j];
                get(sum,pc[i],aa,bb);
                sum+=pc[i];
            }
        }
        ll res=0;
        for(int i=1;i<=sum;i++)
            res=(res+fac[i]*aa[i])%mod;
        printf("Case %d: %lld\n",++pp,res);
    }

    return 0;

}

©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页