题意:给一串数字,问有多少个序列能够组成2048
思路分析:DP+组合数学,如果就一个DP就还好做,加上组合数学,那我就晕了。思路是参考大神的。dp[i][j]表示2^i选了j个有多少个序列,状态转移dp[i][j+k/2] += C(n,j)*dp[i-1][k]。还有组合数是怎么用的很关键,看代码吧。
这题非常卡时间,有两点注意下:1、这题的输入数据量很大,要进行输入优化;2、取模运算非常耗时,尽量少用取模运算,我用add来减少取模运算
代码如下:
#include<iostream>
#include<algorithm>
#include<cstring>
#include<string>
#include<stack>
#include<queue>
#include<set>
#include<map>
#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#define LL __int64
#define MOD 998244353
#define N 100005
#define M 2050
#define inf 0x7ffffff
#define eps 1e-9
#define pi acos(-1.0)
using namespace std;
int a[20],vis[M],cnt_c[20];
LL p[N],pw[N],pm[N];
LL dp[20][M];
LL pow_mul(LL a,int n)
{
LL b = 1;
while(n)
{
if(n&1) b = b*a%MOD;
n >>= 1;
a = a*a%MOD;
}
return b;
}
void init()
{
memset(vis,-1,sizeof(vis));
int i;
for(i = 0; i < 12; i++)
{
vis[1<<i] = i;
cnt_c[i] = 2048/(1<<i);
}
p[0] = 1;
for(i = 1; i < N; i++)
p[i] = p[i-1]*2%MOD;
pm[0] = pw[0] = 1;
for(i = 1; i < N; i++)
{
pm[i] = pm[i-1]*i%MOD;
pw[i] = pow_mul(pm[i],MOD-2);
}
}
LL C(int n,int m)
{
return pm[n]*pw[m]%MOD*pw[n-m]%MOD;
}
void add(LL &a,LL b)
{
a += b;
if(a >= MOD) a-= MOD;
}
void solve()
{
memset(dp,0,sizeof(dp));
int i,j,k;
int num = min(a[0],cnt_c[0]);
LL sum = 0,val;
for(i = 0; i <= num; i++){
val = C(a[0],i);
dp[0][i] = val;
add(sum,val);
}
if(cnt_c[0] < a[0])
dp[0][cnt_c[0]] = (dp[0][cnt_c[0]] + p[a[0]] - sum + MOD)%MOD;
for(i = 1; i < 12; i++)
{
num = min(a[i],cnt_c[i]);
sum = 0;
for(j = 0; j <= num; j++)
{
val = C(a[i],j);
add(sum,val);
for(k = 0; k <= cnt_c[i-1]; k++)
if(dp[i-1][k])
{
int temp = min(j+k/2,cnt_c[i]);
add(dp[i][temp],val*dp[i-1][k]%MOD);
}
}
if(a[i] > cnt_c[i]){
val = (p[a[i]]-sum + MOD)%MOD;
for(j = 0; j <= cnt_c[i-1]; j++)
if(dp[i-1][j])
add(dp[i][cnt_c[i]],dp[i-1][j]*val%MOD);
}
}
}
void read(int &x)
{
char ch;
while((ch=getchar())<'0'||ch>'9');
x=ch-'0';
while((ch=getchar())>='0'&&ch<='9')
x=x*10+ch-'0';
}
int main()
{
//freopen("input.txt","r",stdin);
//freopen("output.txt","w",stdout);
int n,cnt,cas = 1;
init();
while(scanf("%d",&n) && n)
{
int i;
cnt = 0;
memset(a,0,sizeof(a));
for(i = 0; i < n; i++)
{
int x;
read(x);
if(vis[x] == -1) cnt++;
else a[vis[x]]++;
}
printf("Case #%d: ",cas++);
solve();
printf("%I64d\n",dp[11][1]*p[cnt]%MOD);
}
return 0;
}