题意
给定n行m列数,对于 [0,2^k-1] 内的数x求 C o u n t ( x ) = ∑ i = 1 n ∏ j = 1 m [ x & a i , j 的 二 进 制 表 达 式 奇 数 ] Count(x) =\sum_{i=1}^{n} \prod_{j=1}^{m}[x\&a_{i,j}的二进制表达式奇数] Count(x)=∑i=1n∏j=1m[x&ai,j的二进制表达式奇数]
分析
- 求 Count(x)
来自Qls, ∣ a i , j ∧ X ∣ |a_{i,j}\land X| ∣ai,j∧X∣代表 a i , j & X a_{i,j}\&X ai,j&X的奇偶性,如果连乘式中有一个为偶, 整个就为0
- 把连乘展开
得到 ∏ j = 1 m ( 1 − ( − 1 ) ∣ a i , j ∧ x ∣ ) = 1 + ∑ j 不 等 于 k ( − 1 ) ∣ a i , j ∧ x ∣ + ∣ a i , k ∧ x ∣ − . . . . \prod_{j=1}^{ m}(1-(-1)^{|a_{i,j}\land x|})=1+\sum_{j 不等于 k}(-1)^{|a_{i,j}\land x|+|a_{i,k}\land x|}-.... j=1∏m(1−(−1)∣ai,j∧x∣)=1+j不等于k∑(−1)∣ai,j∧x∣+∣ai,k∧x∣−....
我们知道
∣ ( i ∧ j ) ∣ + ∣ ( i ∧ k ) ∣ 的 奇 偶 性 = ∣ ( i ∧ ( j ⨁ k ) ) ∣ 的 奇 偶 性 |(i\land j)|+|(i\land k)| 的奇偶性= |(i\land (j \bigoplus k))|的奇偶性 ∣(i∧j)∣+∣(i∧k)∣的奇偶性=∣(i∧(j⨁k))∣的奇偶性
∏
j
=
1
m
(
1
−
(
−
1
)
∣
a
i
,
j
∧
x
∣
)
=
1
+
∑
j
不
等
于
k
(
−
1
)
∣
x
∧
(
a
i
,
j
⨁
a
i
,
k
)
∣
+
.
.
.
.
\prod_{j=1}^{ m}(1-(-1)^{|a_{i,j}\land x|})=1+\sum_{j 不等于 k}(-1)^{|x\land (a_{i,j} \bigoplus a_{i,k} ) |}+....
j=1∏m(1−(−1)∣ai,j∧x∣)=1+j不等于k∑(−1)∣x∧(ai,j⨁ai,k)∣+....
3. FWT_XOR 正好就是我们需要的
(C1表示i&j奇偶性为0,C2表示i&j的奇偶性为1)
参考代码
const LL mod = 1e9 + 7;
LL qpow(LL a, LL b) {LL s = 1; while (b > 0) {if (b & 1)s = s * a % mod; a = a * a % mod; b >>= 1;} return s;}
LL gcd(LL a, LL b) {return b ? gcd(b, a % b) : a;}
int dr[2][4] = {1, -1, 0, 0, 0, 0, -1, 1};
typedef pair<int, int> P;
// 异或
void FWT(int *a, int N, int opt) {
const int inv2 = qpow(2, mod - 2);
// j是区间开始点,i是区间距离,k是具体位置,j+k,i+j+k就是在a数组中的坐标
for (int i = 1; i < N; i <<= 1) {
for (int p = i << 1, j = 0; j < N; j += p) {
for (int k = 0; k < i; ++k) {
LL X = a[j + k], Y = a[i + j + k];
a[j + k] = (X + Y) % mod;
a[i + j + k] = (X + mod - Y) % mod;
if (opt == -1) a[j + k] = 1ll * a[j + k] * inv2 % mod, a[i + j + k] = 1ll * a[i + j + k] * inv2 % mod;
}
}
}
}
const int maxn = 1 << 21;
int a[maxn];
int v[maxn];
void dfs(int *a, int x, int m, int sign, int t) {
if (x > m) {
v[t] += sign;
return ;
}
dfs(a, x + 1, m, sign, t);
dfs(a, x + 1, m, -sign, t ^ a[x]);
}
int main(void)
{
int n, m, k;
while (cin >> n >> m >> k) {
for (int i = 0; i < (1 << k); ++i)
v[i] = 0;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
scanf("%d", &a[j]);
}
dfs(a, 1, m, 1, 0);
}
int N = 1 << k;
FWT(v, N, 1);
LL sum = 0;
LL inv = qpow(1<<m, mod - 2);
LL tmp = 1;
for (int i = 0; i < N; ++i) {
sum ^= v[i] * tmp % mod * inv % mod;
tmp = tmp*3%mod;
}
cout << sum << endl;
}
return 0;
}