线性变换>->
类比FFT
对于这类东西,我们考虑
t
f
(
A
)
t
f
(
B
)
=
t
f
(
A
∗
B
)
tf(A)tf(B)=tf(A*B)
tf(A)tf(B)=tf(A∗B),其中*为某二元运算,tf为线性变换,设
C
=
A
∗
B
C=A*B
C=A∗B
形象的我们可以把tf认为
t
f
(
A
)
i
=
∑
j
=
0
n
A
j
f
(
n
,
i
,
j
)
tf(A)_{i}=\sum_{j=0}^{n}A_{j}f(n,i,j)
tf(A)i=∑j=0nAjf(n,i,j),其中f(i,j)是一个函数
所以,我们有
t
f
(
A
)
i
t
f
(
B
)
i
=
t
f
(
A
∗
B
)
i
tf(A)_{i}tf(B)_{i}=tf(A*B)_{i}
tf(A)itf(B)i=tf(A∗B)i
(
∑
j
=
0
n
A
j
f
(
n
,
i
,
j
)
)
(
∑
j
=
0
n
B
j
f
(
n
,
i
,
j
)
)
=
∑
j
=
0
n
C
j
f
(
n
,
i
,
j
)
(\sum_{j=0}^{n}A_{j}f(n,i,j))(\sum_{j=0}^{n}B_{j}f(n,i,j))=\sum_{j=0}^{n}C_{j}f(n,i,j)
(∑j=0nAjf(n,i,j))(∑j=0nBjf(n,i,j))=∑j=0nCjf(n,i,j)
f
(
n
,
i
,
j
)
f
(
n
,
i
,
k
)
=
f
(
n
,
i
,
j
∗
k
)
f(n,i,j)f(n,i,k)=f(n,i,j*k)
f(n,i,j)f(n,i,k)=f(n,i,j∗k)
如果是FFT那么我们有
f
(
n
,
i
,
j
)
=
(
w
n
i
)
j
f(n,i,j)=(w_{n}^{i})^{j}
f(n,i,j)=(wni)j,易证明
f
(
n
,
i
,
j
)
f
(
n
,
i
,
k
)
=
f
(
n
,
i
,
j
∗
k
)
f(n,i,j)f(n,i,k)=f(n,i,j*k)
f(n,i,j)f(n,i,k)=f(n,i,j∗k)
现在我们需要位运算卷积
or卷积
我们定义
f
(
n
,
i
,
j
)
=
[
(
j
∣
i
)
=
=
i
]
f(n,i,j)=[(j|i)==i]
f(n,i,j)=[(j∣i)==i],显然满足
f
(
n
,
i
,
j
)
f
(
n
,
i
,
k
)
=
f
(
n
,
i
,
j
∗
k
)
f(n,i,j)f(n,i,k)=f(n,i,j*k)
f(n,i,j)f(n,i,k)=f(n,i,j∗k),所以我们有
t
f
(
A
)
i
=
∑
j
=
0
n
A
j
[
(
j
∣
i
)
=
=
i
]
tf(A)_{i}=\sum_{j=0}^{n}A_{j}[(j|i)==i]
tf(A)i=∑j=0nAj[(j∣i)==i]
接下来我们考虑怎么计算tf(A),貌似可以直接子集和…
我们记
A
0
,
A
1
A_{0},A_{1}
A0,A1表示A的低
2
n
−
1
2^{n-1}
2n−1位与高
2
n
−
1
2^{n-1}
2n−1位,这时我们发现
t
f
(
A
)
=
(
t
f
(
A
0
)
,
t
f
(
A
1
)
+
t
f
(
A
0
)
)
tf(A)=(tf(A_{0}),tf(A_{1})+tf(A_{0}))
tf(A)=(tf(A0),tf(A1)+tf(A0))
定义
I
t
f
为
t
f
Itf为tf
Itf为tf的逆变换
故
I
f
(
n
,
i
,
j
)
=
[
j
&
i
=
=
i
]
(
−
1
)
(
c
o
u
n
t
(
j
)
−
c
o
u
n
t
(
i
)
)
If(n,i,j)=[j\&i==i](-1)^{(count(j)-count(i))}
If(n,i,j)=[j&i==i](−1)(count(j)−count(i))
I
t
f
(
A
)
=
(
I
t
f
(
A
0
)
,
I
t
f
(
A
1
)
−
I
t
f
(
A
0
)
)
Itf(A)=(Itf(A_{0}),Itf(A_{1})-Itf(A_{0}))
Itf(A)=(Itf(A0),Itf(A1)−Itf(A0))
and卷积
我们定义
f
(
n
,
i
,
j
)
=
[
j
&
i
=
=
i
]
f(n,i,j)=[j\&i==i]
f(n,i,j)=[j&i==i],显然满足
f
(
n
,
i
,
j
)
f
(
n
,
i
,
k
)
=
f
(
n
,
i
,
j
∗
k
)
f(n,i,j)f(n,i,k)=f(n,i,j*k)
f(n,i,j)f(n,i,k)=f(n,i,j∗k),所以
t
f
(
A
)
i
=
∑
j
=
0
n
A
j
[
j
&
i
=
=
i
]
tf(A)_{i}=\sum_{j=0}^{n}A_{j}[j\&i==i]
tf(A)i=∑j=0nAj[j&i==i]
计算
t
f
(
A
)
tf(A)
tf(A),超集和
t
f
(
A
)
=
(
t
f
(
A
0
)
+
t
f
(
A
1
)
,
t
f
(
A
1
)
)
tf(A)=(tf(A_{0})+tf(A_{1}),tf(A_{1}))
tf(A)=(tf(A0)+tf(A1),tf(A1))
故
I
f
(
n
,
i
,
j
)
=
[
j
∣
i
=
=
i
]
(
−
1
)
c
o
u
n
t
(
i
)
−
c
o
u
n
t
(
j
)
If(n,i,j)=[j|i==i](-1)^{count(i)-count(j)}
If(n,i,j)=[j∣i==i](−1)count(i)−count(j)
I
t
f
(
A
)
=
(
I
t
f
(
A
0
)
−
I
t
f
(
A
1
)
,
I
f
t
(
A
1
)
)
Itf(A)=(Itf(A_{0})-Itf(A_{1})_,Ift(A_{1}))
Itf(A)=(Itf(A0)−Itf(A1),Ift(A1))
xor卷积
我们定义
f
(
n
,
i
,
j
)
=
(
−
1
)
c
o
u
n
t
(
j
&
i
)
f(n,i,j)=(-1)^{count(j\&i)}
f(n,i,j)=(−1)count(j&i),显然满足
f
(
n
,
i
,
j
)
f
(
n
,
i
,
k
)
=
f
(
n
,
i
,
j
∗
k
)
f(n,i,j)f(n,i,k)=f(n,i,j*k)
f(n,i,j)f(n,i,k)=f(n,i,j∗k)
所以
t
f
(
A
)
i
=
∑
j
=
0
n
A
j
(
−
1
)
c
o
u
n
t
(
j
&
i
)
tf(A)_{i}=\sum_{j=0}^{n}A_{j}(-1)^{count(j\&i)}
tf(A)i=∑j=0nAj(−1)count(j&i)
t
f
(
A
)
=
(
t
f
(
A
0
)
+
t
f
(
A
1
)
,
t
f
(
A
0
)
−
t
f
(
A
1
)
)
tf(A)=(tf(A_{0})+tf(A_{1}),tf(A_{0})-tf(A_{1}))
tf(A)=(tf(A0)+tf(A1),tf(A0)−tf(A1))
故
I
f
(
n
,
i
,
j
)
=
1
n
∑
j
=
0
n
A
j
(
−
1
)
c
o
u
n
t
(
j
&
i
)
If(n,i,j)=\frac{1}{n}\sum_{j=0}^{n}A_{j}(-1)^{count(j\&i)}
If(n,i,j)=n1∑j=0nAj(−1)count(j&i)
I
t
f
(
A
)
=
(
t
f
(
A
0
)
+
t
f
(
A
1
)
2
,
t
f
(
A
0
)
−
t
f
(
A
1
)
2
)
Itf(A)=(\frac{tf(A_{0})+tf(A_{1})}{2},\frac{tf(A_{0})-tf(A_{1})}{2})
Itf(A)=(2tf(A0)+tf(A1),2tf(A0)−tf(A1))
注:以上用线性变换来解释fwt,貌似有点不太行
具体可以看2015年论文,集合幂级数貌似靠谱一点
【模板】快速沃尔什变换
#include<bits/stdc++.h>
#define ll long long
#define pb(x) push_back(x)
using namespace std;
const int mod=998244353;
typedef vector<int> poly;
poly a_or,b_or,a_and,b_and,a_xor,b_xor;
int n,x,inv2;
int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
int mul(int x,int y){return (ll)x*y%mod;}
int ksm(int x,int y){
int ans=1;
for (;y;y>>=1,x=mul(x,x)) if (y&1) ans=mul(ans,x);
return ans;
}
inline poly operator*(poly a,poly b){
for (int i=0;i<a.size();i++) a[i]=mul(a[i],b[i]);
return a;
}
void Fwt_or(poly &a,int t){
for (int i=1;i<(1<<n);i<<=1)
for (int j=0,p=(i<<1);j<(1<<n);j+=p)
for (int k=j;k<j+i;k++) if (t==1) a[k+i]=(a[k+i]+a[k])%mod;
else a[k+i]=dec(a[k+i],a[k]);
}
void Fwt_and(poly &a,int t){
for (int i=1;i<(1<<n);i<<=1)
for (int j=0,p=(i<<1);j<(1<<n);j+=p)
for (int k=j;k<j+i;k++) if (t==1) a[k]=add(a[k+i],a[k]);
else a[k]=dec(a[k],a[k+i]);
}
void Fwt_xor(poly &a,int t){
for (int i=1;i<(1<<n);i<<=1)
for (int j=0,p=(i<<1);j<(1<<n);j+=p)
for (int k=j;k<j+i;k++) {
int x=a[k],y=a[k+i];
if (t==1) a[k]=add(x,y),a[k+i]=dec(x,y);
else a[k]=mul(add(x,y),inv2),a[k+i]=mul(dec(x,y),inv2);
}
}
int main(){
scanf("%d",&n);
for (int i=0;i<(1<<n);i++) scanf("%d",&x),a_or.pb(x),a_and.pb(x),a_xor.pb(x);
for (int i=0;i<(1<<n);i++) scanf("%d",&x),b_or.pb(x),b_and.pb(x),b_xor.pb(x);
inv2=ksm(2,mod-2);
Fwt_or(a_or,1); Fwt_or(b_or,1); a_or=a_or*b_or; Fwt_or(a_or,-1);
Fwt_and(a_and,1); Fwt_and(b_and,1); a_and=a_and*b_and; Fwt_and(a_and,-1);
Fwt_xor(a_xor,1); Fwt_xor(b_xor,1); a_xor=a_xor*b_xor; Fwt_xor(a_xor,-1);
for (int i=0;i<(1<<n);i++) printf("%d ",a_or[i]);
printf("\n");
for (int i=0;i<(1<<n);i++) printf("%d ",a_and[i]);
printf("\n");
for (int i=0;i<(1<<n);i++) printf("%d ",a_xor[i]);
}
【UNR #2】黎明前的巧克力
题意:
你现在有一个数集T,要从中选出一个子集s(s不为空),如果xor为0则对答案的贡献为2|s|否则不对答案产生贡献
解析:就是求
∏
i
=
1
n
(
1
+
2
x
a
i
)
\prod_{i=1}^{n}(1+2x^{a_{i}})
∏i=1n(1+2xai)的第0项系数,这里的
∏
\prod
∏为异或卷积
我们可以考虑把每个
(
1
+
2
x
a
i
)
(1+2x^{a_{i}})
(1+2xai)Fwt一下,再相乘,再Fwt回去,这样时间复杂度为
O
(
n
2
l
o
g
n
)
O(n^2logn)
O(n2logn)
我们考虑对
(
1
+
2
x
a
i
)
(1+2x^{a_{i}})
(1+2xai)Fwt本质上每位只会是3或-1,那么最后我们只需要求出每一位上有多少个3,有多少个-1即可.
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int M=20;
const int mod=998244353;
int f[1<<M],fac[1<<M];
int n,x,inv2;
int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
int mul(int x,int y){return (ll)x*y%mod;}
int ksm(int x,int y){
int ans=1;
for (;y;y>>=1,x=mul(x,x)) if (y&1) ans=mul(ans,x);
return ans;
}
void Fwt(int *a,int opt){
for (int i=1;i<(1<<M);i<<=1)
for (int j=0,p=(i<<1);j<(1<<M);j+=p)
for (int k=j;k<j+i;k++) {
int x=a[k],y=a[k+i];
if (opt==1) a[k]=x+y,a[k+i]=x-y;
else a[k]=mul(add(x,y),inv2),a[k+i]=mul(dec(x,y),inv2);
}
}
signed main(){
scanf("%d",&n);
for (int i=1;i<=n;i++) {
scanf("%d",&x); f[x]++;
}
Fwt(f,1);
fac[0]=1; for (int i=1;i<(1<<M);i++) fac[i]=mul(fac[i-1],3);
for (int i=0;i<(1<<M);i++) {
int x=(n+f[i])/2,y=n-x;
f[i]=fac[x]; if (y&1) f[i]=(mod-f[i])%mod;
}
inv2=ksm(2,mod-2);
Fwt(f,-1);
printf("%d\n",dec(f[0],1));
}