题目链接:https://codeforces.com/contest/665/problem/E
题意:让你找一些子串,让它们异或和大于等于k。求这样的字串有多少个。
思路:
a
i
⊕
a
i
+
1
⊕
.
.
.
.
.
⊕
a
n
=
(
a
1
⊕
a
2
⊕
.
.
.
.
.
⊕
a
n
)
⊕
(
a
1
⊕
a
2
⊕
.
.
.
.
.
⊕
a
i
−
1
)
\begin{aligned} a_i\oplus a_{i+1}\oplus.....\oplus a_n &= (a_1\oplus a_2\oplus.....\oplus a_n) \oplus (a_1\oplus a_2\oplus.....\oplus a_{i-1}) \end{aligned}
ai⊕ai+1⊕.....⊕an=(a1⊕a2⊕.....⊕an)⊕(a1⊕a2⊕.....⊕ai−1)
所以我们自然会想到,对于每一个
a
i
a_i
ai找以
a
i
a_i
ai为结尾的子串,就把
a
1
⊕
a
2
⊕
.
.
.
.
⊕
a
i
a_1\oplus a_2 \oplus....\oplus a_i
a1⊕a2⊕....⊕ai和
i
i
i之前的所有的异或前缀和做一次异或运算,就能得到,以i为结尾的所有子串的异或前缀和。我们把符合条件的答案加起来就好了。
但这样做,时间复杂度为
O
(
n
2
)
O(n^2)
O(n2),肯定不能通过本题,我们考虑优化。
我开始考虑了一下线段树,后来没做出来。我看了大佬的通过代码,发现可以用01字典树,对所有的异或前缀和进行拆位存储,这样,就很大程度地降低了时间复杂度。
我们用字典树,现在求以
a
i
a_i
ai为结尾符合答案的子串数量,假设前
i
−
1
i-1
i−1项的前缀和情况存在了字典树中,我们这里用
y
y
y表示。
也就是,我们要求
a
i
⊕
y
>
=
k
a_i \oplus y >= k
ai⊕y>=k的数量。
我们按位从高位到低位考虑:
如果
a
i
a_i
ai的当前位是0,k的当前为也是0,此时
y
y
y的当前为1的情况都成立,直接加入答案中;y当前位为0,就继续考虑向下考虑。
如果
a
i
a_i
ai的当前位为0,k的当前位为1,此时
y
y
y的当前位如果可以为1,就继续向下考虑。
如果
a
i
a_i
ai的当前位为1,k的当前位为0,此时
y
y
y的当前位为0的情况都成立,就直接加入答案;y当前位为1,就继续向下考虑。
如果
a
i
a_i
ai的当前位为1,k的当前位为1,此时
y
y
y的当前位为0,就继续向下考虑。
它的过程,类似于数位dp,遇到大于的情况直接加入答案,如果遇到等于的情况,就继续考虑下一位,因为等于的情况只有等到所有位都判断完,才能得出结论。如果遇到了小于的情况,直接忽略这种情况,如果当前位已经比预期小了,无论,更低的位如何选择,你都不会得出大于等于预期的情况。
代码如下:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
//字典树,异或前缀和
const int N = 3e7 + 10;
int tot = 0;
int trie[N][2];
int cnt[N];
int a[1000010];
int n, k;
void insert(int x){
int rt = 0;
for(int i = 30; i >= 0; i--){
int now = x >> i & 1;
if(!trie[rt][now]){
trie[rt][now] = ++tot;
}
rt = trie[rt][now];
cnt[rt]++;
}
}
int find(int x, int k){
//找目前所有的前缀中,异或x大于等于k的个数
int ans = 0, rt = 0;
for(int i = 30; i >= 0; i--){
int nowx = x >> i & 1;
int nowk = k >> i & 1;
if(!nowx){
if(!nowk){
ans += cnt[trie[rt][1]];
if(!trie[rt][0]){
return ans;
}
rt = trie[rt][0];
}
else{
if(!trie[rt][1]){
return ans;
}
rt = trie[rt][1];
}
}
else{
if(!nowk){
ans += cnt[trie[rt][0]];
if(!trie[rt][1]){
return ans;
}
rt = trie[rt][1];
}
else{
if(!trie[rt][0]){
return ans;
}
rt = trie[rt][0];
}
}
}
return ans + cnt[rt];//相等的情况不要漏掉
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n >> k;
for(int i = 1; i <= n; i++){
cin >> a[i];
a[i] = a[i]^a[i-1];
}
ll ans = 0;
insert(0);
for(int i = 1; i <= n; i++){
ans += 1LL*find(a[i], k);
insert(a[i]);
}
cout << ans << "\n";
return 0;
}