题意:给你n个数字,问你最多能够把序列分成几份,使得每段长度不超过l,且异或和不超过x,数据范围(1<= L<=N<=100000, 0<=X<268435456)。
01字典树是个好东西,可惜我没有=_=
最简单的dp方程就是dp[i][j]表示前i个分成j段是否可行,再优化一步可以变成dp[i]=j表示前i个最多分成j段,每次枚举之前的合法的位置k,dp[i]=max(dp[k])+1。
可是这样是n方的,求区间的异或值还要o(n),所以首先,由于[l,r]的亦或值是等于[1,r]^[1,l-1]的,所以我们可以先预处理出前缀异或和,这样转移就只是o(n)找合法位置了。
那么这个合法的位置怎么找,就是要用到01字典树了。
首先,我们01字典树里面存的是前缀异或和,且在每个节点多开两个数组cnt与val,表示当前前缀异或和下有cnt个位置,以及他们的dp值的最大值是val。
接着在我们在枚举i,只需查询01字典树里与当前前缀异或和sum[i]异或下来小于等于x的值即可,具体是下面这样做的:
1.如果当前i-l+1的位置曾经插入过一个dp值,就把它从01字典树中删除。
2.判断x的当前位是否为1,如果当前位为0,那么说明我们需要查找一个和sum[i]当前位相等的值,由于当前异或值在x的边界上,不能直接return 查询节点的val,所以继续向下递归,如果当前位为1,那么我们找与sum[i]异或值为0,1都可以,如果异或值为0,这一位已经不在x的边界上了,可以直接与sum[i]异或为0的那个儿子的val来更新答案,如果为1,同理在边界上,继续更新答案(有一点点数位dp的思想),复杂度为logx。
3.如果当前位置能找到一个与它异或值小于x的,那么就将它插入到01字典树中。
下附AC代码。
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define maxn 100005
using namespace std;
typedef long long ll;
const ll mod=268435456;
ll n,x,l,p,q,tot;
ll dp[maxn];
int son[maxn*32][2];
ll a[maxn],sum[maxn];
ll val[maxn*32],cnt[maxn*32];
void insert(ll now,ll pos,ll num)
{
if(pos==-1)
{
val[now]=dp[num];
return;
}
ll nex=(sum[num]>>pos)&1;
if(!son[now][nex]){son[now][nex]=++tot;cnt[tot]=0;val[tot]=-1;}
cnt[son[now][nex]]++;
insert(son[now][nex],pos-1,num);
val[now]=max(val[now],val[son[now][nex]]);
}
void del(ll now,ll pos,ll num)
{
if(pos==-1)
{
if(!cnt[now]) val[now]=-1;
return;
}
ll nex=(sum[num]>>pos)&1;
cnt[son[now][nex]]--;
del(son[now][nex],pos-1,num);
val[now]=val[son[now][nex]];
if(son[now][nex^1] && cnt[son[now][nex^1]])
val[now]=max(val[now],val[son[now][nex^1]]);
}
ll query(ll now,ll pos,ll num)
{
if(pos==-1) return val[now];
ll nex=((sum[num]>>pos)&1),maxx=((x>>pos)&1);
ll ans=-1;
if(maxx==1)
{
if(son[now][nex]&&cnt[son[now][nex]])
ans=max(ans,val[son[now][nex]]);
if(son[now][nex^1]&&cnt[son[now][nex^1]])
ans=max(ans,query(son[now][nex^1],pos-1,num));
}
else
{
if(son[now][nex] && cnt[son[now][nex]])
ans=max(ans,query(son[now][nex],pos-1,num));
}
return ans;
}
int main()
{
ll _;
scanf("%lld",&_);
while(_--)
{
scanf("%lld%lld%lld",&n,&x,&l);
tot=0;
memset(dp,0,sizeof(dp));
memset(val,-1,sizeof(val));
memset(son,0,sizeof(son));
memset(cnt,0,sizeof(cnt));
scanf("%lld%lld%lld",&a[1],&p,&q);
sum[1]=a[1];
for(ll i=2;i<=n;i++)
{
a[i]=(a[i-1]*p+q)%mod;
sum[i]=a[i]^sum[i-1];
}
insert(0,30,0);
for(ll i=1;i<=n;i++)
{
if(i>l && dp[i-l-1]) del(0,30,i-l-1);
if(i==l+1) del(0,30,0);
ll now=query(0,30,i);
if(now>=0)
{
dp[i]=now+1;
insert(0,30,i);
}
}
printf("%lld\n",dp[n]);
}
}