题意:给定一个长度为n的数组和一个数k,求有多少个区间[l,r]满足a[l]^a[l+1]^..^a[r]>=k。
分析:我们设sum[i]=a[1]^..^a[i],那么对于每个sum[i]我们要找出有多少个sum[j]^sum[i]>=k且j<i。我们将sum[0]~sum[i-1]全部存入字典树,然后每次在树中查询即可,为了方便统计我们改为sum[j]^sum[i]>k-1变成严格大于能直接处理出所有的答案而不用再查找有多少个等于k的情况。详见代码。
代码:
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<math.h>
#include<vector>
#include<string>
#include<stdio.h>
#include<cstring>
#include<iostream>
#include<algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
const int N=1000010;
const int M=1e5+7;
const int HASHSIZE=3e5+9;
const int mod=100000000;
const int MOD1=1000000007;
const int MOD2=1000000009;
const double EPS=0.00000001;
typedef long long ll;
const ll MOD=1000000007;
const int INF=~0u>>1;
const ll MAX=1ll<<55;
const double pi=acos(-1.0);
typedef double db;
typedef unsigned long long ull;
int tot,a[N],K[30],bit[30],siz[20*N],tr[20*N][2];
void add(int x) {
int i,d,w=0;siz[0]++;
for (i=29;i>=0;i--) {
d=(x&bit[i])==bit[i];
if (tr[w][d]) w=tr[w][d],siz[w]++;
else tr[w][d]=++tot,w=tot,siz[w]=1;
}
}
int get(int x) {
int i,b,d,w=0,ret=0;
for (i=29;i>=0;i--) {
d=(x&bit[i])==bit[i];
if (!K[i]&&tr[w][!d]) ret+=siz[tr[w][!d]];
if (K[i]) w=tr[w][!d];
else w=tr[w][d];
if (!w) break ;
}
return ret;
}
int main()
{
int i,n,k;
ll ans=0;
scanf("%d%d", &n, &k);
for (i=0;i<30;i++) bit[i]=1<<i;
k--;a[0]=tot=0;add(a[0]);
for (i=0;i<30;i++) K[i]=(k&bit[i])==bit[i];
for (i=1;i<=n;i++) scanf("%d", &a[i]),a[i]^=a[i-1];
for (i=1;i<=n;i++) {
ans+=(ll)get(a[i]);add(a[i]);
}
printf("%I64d\n", ans);
return 0;
}