题目大意:从 N 个数字中选 1 到 k 个,使得这些数字的与运算结果为 s, 或运算结果为 t,为有多少种选择方案。
首先我们只关注 集合a 中满足 (x & s) == s && (x & t) == x
的元素 x,称这个集合为 X:{x},令 t = t ^ s
,通过预处理,将 t 中为 1的位提出来,此时问题等价于:s = 0, t =
2
L
−
1
2^L - 1
2L−1。
这相当于要从集合 X 中选出 [1,k] 个元素作为一个新的集合,使得这个新集合中在这 L 位上每一位至少有一个0,至少有一个1,等价于新集合中每一个元素在这 L 位都是不完全相同的。
通过二项式反演来解决这个问题:
设 f(n):恰好 n 位不完全相同的选择方案;g(n):最多 n 位不完全相同的选择方案
显然
f
(
L
)
f(L)
f(L) 是所求解
g
(
n
)
=
∑
i
=
0
n
C
(
n
,
i
)
∗
f
(
i
)
g(n) = \sum_{i = 0}^nC(n,i)*f(i)
g(n)=i=0∑nC(n,i)∗f(i)
可以反演得到:
f
(
n
)
=
∑
i
=
0
n
(
−
1
)
n
−
i
∗
C
(
n
,
i
)
∗
g
(
i
)
f(n) = \sum_{i = 0}^n(-1)^{n-i}*C(n,i)*g(i)
f(n)=i=0∑n(−1)n−i∗C(n,i)∗g(i)
问题在于如何求
C
(
n
,
i
)
∗
g
(
i
)
C(n,i) * g(i)
C(n,i)∗g(i):
定义
h
(
U
)
:
h(U):
h(U): 所选集合中与 U进行与运算结果相同的选择方案数。
由于与运算结果相同,等价于选出来的集合中:满足 U 中为 1 的位,这些数在这一位的值相同;U 中为 0 的位,这些数在这些位的取值可能相同也可能不同。
假设 U 有 x x x 位为 1,那么相当于选出来的集合最多有 L - x 位不相同,这恰好是 g(L - x) 的定义
因此 f ( n ) = ∑ i = 0 n ( − 1 ) n − i ∗ ∑ U ∈ ( U 中 有 i 个 0 ) h ( U ) f(n) = \displaystyle\sum_{i = 0}^n(-1)^{n-i}*\sum_{U\in{(U中有i个0)}}h(U) f(n)=i=0∑n(−1)n−i∗U∈(U中有i个0)∑h(U)
改变枚举项,先枚举 U,可以得到 f ( L ) = ∑ U = 0 2 L − 1 ( − 1 ) p o p c o u n t ( U ) ∗ h ( U ) f(L) = \displaystyle\sum_{U=0}^{2^L-1}(-1)^{popcount(U)}*h(U) f(L)=U=0∑2L−1(−1)popcount(U)∗h(U)
其中 p o p c o u n t ( U ) = U popcount(U) = U popcount(U)=U中1的个数
h ( U ) h(U) h(U) 可以通过遍历一遍数组求得。
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e2 + 10;
int a[maxn], n, k, s, t, tot;
int vis[300010];
ll C[51][51], D[51][51];
int main() {
scanf("%d%d%d%d",&n,&k,&s,&t);
for (int i = 0; i <= n; i++)
D[i][0] = C[i][0] = 1;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= i; j++) {
C[i][j] = C[i - 1][j] + C[i - 1][j - 1];
D[i][j] = D[i][j - 1] + C[i][j];
}
for (int i = 1, x; i <= n; i++) {
scanf("%d",&x);
if ((x & s) == s && (x & t) == x)
a[++tot] = x;
}
t ^= s;
int cnt = 0;
for (int j = 0; j < 18; j++)
if (t >> j & 1) cnt++;
//printf("%d %d\n",cnt,tot);
for (int i = 1; i <= tot; i++) {
int val = 1, sum = 0;
for (int j = 0; j < 18; j++) {
if ((a[i] >> j & 1) && (t >> j & 1))
sum += val;
if (t >> j & 1)
val *= 2;
}
a[i] = sum;
}
ll ans = 0;
for (int i = 0; i < (1 << cnt); i++)
vis[i] = 0;
for (int t = 0; t < (1 << cnt); t++) {
int s = 1;
for (int j = 0; j < cnt; j++)
if (t >> j & 1) s *= -1;
ll sum = 0;
for (int i = 1; i <= tot; i++) {
sum -= D[vis[a[i] & t]][min(vis[a[i] & t],k)];
vis[a[i] & t]++;
sum += D[vis[a[i] & t]][min(vis[a[i] & t],k)];
}
for (int i = 1; i <= tot; i++)
vis[a[i] & t]--;
ans += s * sum;
}
printf("%lld\n",ans);
return 0;
}