题意
给出n个数a[1…n],要求从其中选出两个子集,要求这两个子集至少有一个不为空集且交集为空,并且它们的异或和相等。
n,a[i]<=1000000
分析
注意到两个子集的异或和相等,那么他们并集的异或和必然为0。那么现在就变成了求所有异或和为0的子集的价值和,其中定义一个大小为
s
s
s的子集的价值
2
s
2^s
2s。
那么我们可以看成每个元素的贡献是
2
2
2,一个集合的贡献就是它里面元素贡献的乘积。
设每一个数a[i]的集合幂级数为
f
(
x
)
i
=
1
+
2
x
a
[
i
]
f(x)_i=1+2x^{a[i]}
f(x)i=1+2xa[i],那么我们要求的就是
∏
i
=
1
n
f
(
x
)
i
\prod_{i=1}^nf(x)_i
i=1∏nf(x)i
这个多项式的第0项系数,其中乘法定义为集合对称差卷积。
这样我们就得到了40分的高分。
考虑如何优化。
很自然的想法是把每一个集合幂级数都FWT一下,全部乘起来之后再逆FWT回去。这样做的话据说可以得到0分的高分。
注意到这些集合幂级数都是
1
+
2
x
a
[
i
]
1+2x^{a[i]}
1+2xa[i]的形式,猜想一下把这样的集合幂级数FWT之后会不会有什么特殊的性质呢?答案是肯定的。
注意到FWT之后每一位都是-1或3,且FWT的和等于和的FWT。
那么我们可以先把每个集合幂级数加起来,做一次FWT。
对于FWT之后为s的某一项,设这位有x个数为-1,那么3的数量就有(n-x)个。列出方程
3
(
n
−
x
)
−
x
=
s
3(n-x)-x=s
3(n−x)−x=s,解得
x
=
3
n
−
s
4
x=\frac{3n-s}{4}
x=43n−s。那么FWT后这一项的值就是
(
−
1
)
x
3
n
−
x
(-1)^x3^{n-x}
(−1)x3n−x。
求出每一项之后再逆FWT回去即可。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N=1000005;
const int MOD=998244353;
const int ny2=(MOD+1)/2;
int n,a[N],bin[22],f[1048570],g[1048570],po[N];
void fwt(int *f,int l,int r)
{
if (l==r) return;
int len=(r-l+1)/2,mid=(l+r)/2;
for (int i=l;i<=mid;i++)
{
int u=f[i],v=f[i+len];
f[i]=u+v;f[i+len]=u-v;
}
fwt(f,l,mid);fwt(f,mid+1,r);
}
void dwt(int *f,int l,int r)
{
if (l==r) return;
int len=(r-l+1)/2,mid=(l+r)/2;
for (int i=l;i<=mid;i++)
{
int u=f[i],v=f[i+len];
f[i]=(LL)(u+v)*ny2%MOD;f[i+len]=(LL)(u-v)*ny2%MOD;
}
dwt(f,l,mid);dwt(f,mid+1,r);
}
int main()
{
bin[0]=po[0]=1;
for (int i=1;i<=20;i++) bin[i]=bin[i-1]*2;
scanf("%d",&n);
for (int i=1;i<=n;i++) po[i]=(LL)po[i-1]*3%MOD;
for (int i=1;i<=n;i++)
{
int x;scanf("%d",&x);
f[0]++;f[x]+=2;
}
fwt(f,0,bin[20]-1);
for (int i=0;i<bin[20];i++)
{
int x=(LL)(n*3-f[i])/4;
if (x&1) f[i]=-po[n-x];
else f[i]=po[n-x];
}
dwt(f,0,bin[20]-1);
printf("%d",(f[0]+MOD-1)%MOD);
return 0;
}