给你一个整数数组 arr 。
现需要从数组中取三个下标 i、j 和 k
,其中 (0 <= i < j <= k < arr.length)
。
a 和 b 定义如下:
- a = arr[i] ^ arr[i + 1] ^ … ^ arr[j - 1]
- b = arr[j] ^ arr[j + 1] ^ … ^ arr[k]
注意:^
表示 按位异或 操作。
请返回能够令 a == b
成立的三元组 (i, j , k)
的数目。
示例 1:
输入:arr = [2,3,1,6,7]
输出:4
解释:满足题意的三元组分别是 (0,1,2), (0,2,2), (2,3,4) 以及 (2,4,4)
示例 2:
输入:arr = [1,1,1,1,1]
输出:10
示例 3:
输入:arr = [2,3]
输出:0
示例 4:
输入:arr = [1,3,5,7,9]
输出:3
示例 5:
输入:arr = [7,11,12,9,5,2,7,17,22]
输出:8
提示:
1 <= arr.length <= 300
1 <= arr[i] <= 10^8
解题思路: 前缀和
用 ⊕ 表示按位异或运算。
定义长度为 n 的数组 arr 的异或前缀和:
由该定义可得:
这是一个关于 Si
的递推式,根据该递推式我们可以用O(n)的时间得到数组arr 的异或前缀和数组。对于两个下标不同的异或前缀和 Si
和Sj
设 0<i<j && 0<i<j
,有
由于异或运算满足结合律和交换律,且任意数异或自身等于 00,上式可化简为:
从而,数组 arr 的子区间 [i,j]
[i,j]
的元素异或和为可表示为
因此问题中的 a 和 b 可表示为
若 a==b,则有
即:
解法1:(三重循环)
class Solution {
public int countTriplets(int[] arr) {
int len = arr.length;
int[] str = new int[len+1];
for(int i = 0; i < len; ++i){
str[i+1] = str[i] ^ arr[i];
}
int result = 0;
for(int i = 0; i < len; ++i){
for(int j = i+1; j < len; ++j){
for(int k = j; k < len; ++k){
if(str[i] == str[k+1]){
++ result;
}
}
}
}
return result;
}
}
解法2:(两重循环)
class Solution {
public int countTriplets(int[] arr) {
int len = arr.length;
int[] str = new int[len+1];
for(int i = 0; i < len; ++i){
str[i+1] = str[i] ^ arr[i];
}
int result = 0;
for(int i = 0; i < len; ++i){
for(int j = i+1; j < len; ++j){
if(str[i] == str[j+1]){
result += j - i;
}
}
}
return result;
}
}
(二重循环)代码优化:
class Solution {
public int countTriplets(int[] arr) {
int len = arr.length;
int[] str = new int[len+1];
int result = 0;
for(int i = 0; i < len; ++i){
int temp = arr[i];
for(int j = i+1; j < len; ++j){
temp ^= arr[j];
if(temp == 0){
result += j - i;
}
}
}
return result;
}
}
解法3: 官方解答
class Solution {
public int countTriplets(int[] arr) {
int n = arr.length;
int[] s = new int[n + 1];
for (int i = 0; i < n; ++i) {
s[i + 1] = s[i] ^ arr[i];
}
Map<Integer, Integer> cnt = new HashMap<Integer, Integer>();
Map<Integer, Integer> total = new HashMap<Integer, Integer>();
int ans = 0;
for (int k = 0; k < n; ++k) {
if (cnt.containsKey(s[k + 1])) {
ans += cnt.get(s[k + 1]) * k - total.get(s[k + 1]);
}
cnt.put(s[k], cnt.getOrDefault(s[k], 0) + 1);
total.put(s[k], total.getOrDefault(s[k], 0) + k);
}
return ans;
}
}