CF1879D. Sum of XOR Functions
题意:给定一个长度为n的数组a,由非负整数组成。计算
∑
l
=
1
n
∑
r
=
1
n
f
(
l
,
r
)
∗
(
r
−
l
+
1
)
\sum_{l=1}^{n} \sum_{r=1}^{n}f(l,r)*(r-l+1)
l=1∑nr=1∑nf(l,r)∗(r−l+1)
其中
f
(
l
,
r
)
=
a
l
⊕
a
l
+
1
⊕
a
l
+
2
⊕
⋯
⊕
a
r
f(l,r) = a_l \oplus a_{l+1} \oplus a_{l+2} \oplus \cdots \oplus a_r
f(l,r)=al⊕al+1⊕al+2⊕⋯⊕ar
答案对 998244353 取模
思路:首先很容易想到暴力求解的思路,也就是先枚举长度,然后枚举起点,确定每个区间的左右边界,然后计算求和即可。但是这个复杂度太高了。
因此我们要看看有没有其他的办法,因为用到了异或运算,因此我们可以考虑是否可以用位运算来做,会发现,对于一个二进制数的每一位对答案的贡献都是独立的.
例如现在有一个长度为3的区间,其
f
(
l
,
r
)
f(l,r)
f(l,r)的值是10,对应的二进制数就是1010,那么答案就是
10
∗
3
=
30
10*3 = 30
10∗3=30,可以看成是二进制串中
10
=
(
1000
+
10
)
∗
3
10=(1000 + 10)*3
10=(1000+10)∗3括号内为二进制表示,即
1000
(
8
的二进制表示)
∗
3
+
10
(
2
的二进制表示)
∗
3
=
8
∗
3
+
2
∗
3
=
30
1000(8的二进制表示) *3 + 10(2的二进制表示)*3 =8*3 +2*3 =30
1000(8的二进制表示)∗3+10(2的二进制表示)∗3=8∗3+2∗3=30,因此我们可以按位考虑每一位对答案的贡献.
我们可以按位考虑,由
a
i
a_i
ai的范围可以知道,
a
i
a_i
ai的二进制长度最长为30,那么我们就可以看成是30个长度为n的二进制串,问题就转换成了:给定一个01串,对于所有包含奇数个1的区间
[
l
,
r
]
[l,r]
[l,r],计算它们的总区间长度。
那么这道题的解法就是,首先枚举每一个二进制位i,然后遍历数组 a a a的每一个数 j 的第i位二进制,统计 [ 1 j ] [1~j] [1 j]的前缀异或和中0和1出现的个数,对于每个固定的 j,要想知道区间 [ 1 , j ] [1,j] [1,j]中有多少个左边界 l,满足区间 [ l , j ] [l,j] [l,j]有奇数个1,那么只有当区间 [ 1 , l ] [1,l] [1,l]有偶数个1且区间 [ 1 , r ] [1,r] [1,r]有奇数个1,或者区间 [ 1 , l ] [1,l] [1,l]有奇数个1且区间 [ 1 , r ] [1,r] [1,r]有偶数个1,总共会有两种情况
因为还要统计总区间长度,这里有一个小技巧可以快速统计区间总长度,例如当前数 j的前缀异或和为1,那么就要找到前缀中所有异或和为0的位置,并且求出它们的区间长度,我们已经知道了前缀中0的个数cnt,我们只需统计下来所有为0的位置的下标总和sum,那么需要求的区间总长度就是
c
n
t
∗
j
−
s
u
m
cnt*j-sum
cnt∗j−sum 总区间长度知道了,那么答案就是
(
1
<
<
i
)
∗
(
c
n
t
∗
j
−
s
u
m
)
(1<<i)*(cnt*j-sum)
(1<<i)∗(cnt∗j−sum)
代码如下:
#include <iostream>
#include <algorithm>
#include <cstring>
#include <string>
#include <vector>
#include <map>
#include <cmath>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
const int N = 1e5 + 10, mod = 998244353;
int main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
int n;
cin >> n;
vector<ll> s(n + 1);
for (int i = 1; i <= n; i++)
cin >> s[i], s[i] ^= s[i - 1];//首先求一遍前缀异或和
ll ans = 0;
for (int i = 0; i <= 30; i++)//枚举每一个二进制位
{
//注意cnt的初始边界
ll cnt[2] = {1, 0}, sum[2] = {0};
for (int j = 1; j <= n; j++)
{
int bit = s[j] >> i & 1;
if (bit)
{
ans = ((ans % mod + ((1 << i) % mod * (cnt[0] % mod * j % mod - sum[0] % mod) % mod) % mod) + mod) % mod;
//不加取模就等价于 ans +=(1 << i) * (cnt[0] * j - sum[0]);
//(cnt[0] * j - sum[0])可以快速统计出前缀中所有包含奇数个1的区间总和
//注意运算过程中可能会出现负数,所以最后有个+ mod % mod
}
else
{
ans = ((ans % mod + ((1 << i) % mod * (cnt[1] % mod * j % mod - sum[1] % mod) % mod) % mod) + mod) % mod;
//等价于 ans +=(1 << i) * (cnt[1] * j - sum[1]);
}
cnt[bit] = (cnt[bit] + 1) % mod;//分别统计前缀【1~j】中0和1出现的个数
sum[bit] = (sum[bit] + j) % mod;//分别统计前缀【1~j】中异或值为0和1的位置的下标总和
}
}
cout << ans % mod<< "\n";
return 0;
}
···