题面:
Let f(x) denote the number of 1s in the binary representation of x.
Now MianKing has a sequence a0...m−1 and he wants to know the number of integer x∈[0,L] satisfies that: ∀i∈[0,m−1],f(x+i) mod 2=ai
You need to help him calculate the answer.
详细的输入输出和样例见原题链接。
题意:
T 组测试数据,每组给定一个 m 和 L,以及一个下标从 0 开始的01数列 a,定义一个函数 f(x) 为 x 的二进制表示上 1 的个数。问有多少 x∈[0,L] 满足: ∀i∈[0,m−1],f(x+i) mod 2=ai。
题解:
观察数据范围,我们容易发现 m 并不大,0≤m≤100,而 100 的二进制表示为 1100100,总共有7位,而 x+i 后,可能会影响 x 二进制表示上 1 的个数,所以我们可以猜 x+i 只会影响 x 二进制表示上的低7位。但我们还得考虑如果 x+i 后,从第7位往高位进1的情况,在纸上模拟我们容易发现,进1也只会影响 x 的二进制表示上,从第8位开始,往高位连续的 n 个1,再 +1,总共 n+1 位数,如下图所示。
综上分析后,我们需要考虑以下两点
1、对于每个 x,记录从第8位往高位连续1的个数
2、对于每个 x+i,记录其二进制表示上1的个数
那对于那些二进制表示上不足8位的 x 呢?直接暴力枚举呗,反正最大才127(127的二进制表示为 1111111)。还有一个细节,注意到序列 a 是01序列,且题目让我们求 1 的个数的奇偶性,因此我们可以用异或运算替换加法运算来记录 1 的个数,0 表示偶数个 1,1 表示奇数个 1,以此来减少空间。最后,对于这种求一定区间内有多少数满足题目约束条件的问题,基本都用数位DP来解决,时间复杂度为 O( T*((logL)/7*60 + (2^7)*m) ),详细细节看代码。
代码:
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
int T,m,L,f[61][2][2][2],a[100],b[61];
int get(int sum,int pre,int flag)
{
int x=flag?127:L%128;
int ans=0;
for(int i=0;i<=x;i++)
{
int g=1;//g 表示i是否符合题意
for(int j=0;j<m&&g;j++)
{
//__builtin_parity(x)函数,计算x二进制上1的个数
//不进位
if(i+j<128) g&=((__builtin_parity(i+j)^sum)==a[j]);
//进位
//1^pre,是因为进位1
else g&=((__builtin_parity(i+j-128)^sum^(1^pre))==a[j]);
}
ans+=g;
}
return ans;
}
int dp(int pos,int sum,int pre,int flag)
{
/*
pos 当前试填到第几位
sum 从最高位到第8位为止,总共有几个1,0表示有偶数个1,1表示有奇数个1
pre 表示从第8位开始往高位连续1的个数
flag 表示是否试填到原数的边界 0 到边界 1 没到
*/
//记忆化
if(~f[pos][sum][pre][flag]&&flag) return f[pos][sum][pre][flag];
//位数不足8的,直接暴力枚举
if(pos<=7) return f[pos][sum][pre][flag]=get(sum,pre,flag);
//x表示试填的上界
int x=flag?1:b[pos];
int ans=0;
for(int i=0;i<=x;i++)
{
int tp=(flag||i<x);
//sum^i,用^代替+
ans+=dp(pos-1,sum^i,(!pre)&i,tp);
}
if(flag) f[pos][sum][pre][flag]=ans;
return ans;
}
int calc(int x)
{
memset(b,-1,sizeof b );
int p=0;
//用数组b保存L的二进制表示
while(x)
{
b[++p]=x&1;
x>>=1;
}
return dp(p,0,0,0);
}
signed main()
{
std::ios::sync_with_stdio(false);
std::cin.tie(0);
cin>>T;
while(T--)
{
cin>>m>>L;
memset(f,-1,sizeof f );//初始化
for(int i=0;i<m;i++) cin>>a[i];
cout<<calc(L)<<endl;
}
return 0;
}