$$c_i=\sum_{j\otimes k=i}a_jb_k=\sum_{0\le k<2^n}\sum_{0\le j<2^n}[\,j\otimes k=i\,]\,a_jb_k\qquad(\otimes\in\{\oplus,\lor,\land\})$$
FWT(快速沃尔什变换)可以在 $O(n\,2^n)$ 的时间内计算上述式子,其中 $\otimes$ 为按位的二进制逻辑运算:异或($\oplus$)、或($\lor$)、与($\land$);$[P]$ 为艾弗森括号,条件 $P$ 成立时取 1,否则取 0。
集合并卷积(或卷积)
$$c_i=\sum_{j\lor k=i}a_jb_k=\sum_{0\le k<2^n}\sum_{0\le j<2^n}[\,j\lor k=i\,]\,a_jb_k$$
定义:
$$\hat c_i=\sum_{j\subseteq i}c_j$$
其中 $j\subseteq i$ 表示 $j$ 的二进制位集合是 $i$ 的子集(即 $j\lor i=i$),即 $\hat c$ 是 $c$ 的子集和。
所以:
$$\hat c_i=\sum_{0\le j<2^n}\sum_{0\le k<2^n}[\,j\lor k\subseteq i\,]\,a_jb_k$$
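由于 $j\lor k\subseteq i$ 当且仅当 $j\subseteq i$ 且 $k\subseteq i$:
$$\hat c_i=\sum_{j\subseteq i}\sum_{k\subseteq i}a_jb_k=\Big(\sum_{j\subseteq i}a_j\Big)\Big(\sum_{k\subseteq i}b_k\Big)=\hat a_i\,\hat b_i$$
因此只需对 $a$、$b$ 分别求子集和,逐位相乘得到 $\hat c$,再做逆变换(子集差分)即可还原 $c$。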
集合交卷积(与卷积)
$$c_i=\sum_{j\land k=i}a_jb_k=\sum_{0\le k<2^n}\sum_{0\le j<2^n}[\,j\land k=i\,]\,a_jb_k$$
定义:
$$\hat c_i=\sum_{i\subseteq j}c_j$$
即对 $i$ 的所有超集 $j$(满足 $j\land i=i$)求和。
所以:
$$\hat c_i=\sum_{0\le j<2^n}\sum_{0\le k<2^n}[\,i\subseteq j\land k\,]\,a_jb_k$$
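由于 $i\subseteq j\land k$ 当且仅当 $i\subseteq j$ 且 $i\subseteq k$:
$$\hat c_i=\sum_{i\subseteq j}\sum_{i\subseteq k}a_jb_k=\Big(\sum_{i\subseteq j}a_j\Big)\Big(\sum_{i\subseteq k}b_k\Big)=\hat a_i\,\hat b_i$$
因此只需对 $a$、$b$ 分别求超集和,逐位相乘得到 $\hat c$,再做逆变换(超集差分)即可还原 $c$。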
FWT 模板(迭代版):实现了 or 与 xor 卷积(输入 t = 1 时为 or,否则为 xor),and 卷积的写法附在代码注释中。
#include <cstdio>
#include <cstring>
#define R register
/* Prime modulus for all arithmetic, and the base of the output hash. */
const int Mod = 998244353;
const int p = 2341;

/*
 * Returns A^K mod Mod by square-and-multiply.
 * Intended for A in [0, Mod) and K >= 0 (a negative K never terminates).
 */
int Pow(int A, int K)
{
    long long result = 1;
    long long base = A;
    for (; K; K >>= 1) {
        if (K & 1)
            result = result * base % Mod;
        base = base * base % Mod;
    }
    return (int)result;
}
/* File-scope (hence zero-initialised) work arrays: the two inputs and their convolution. */
int a[1 << 20];
int b[1 << 20];
int c[1 << 20];
int main()
{
R int n;
scanf("%d", &n);
for(R int i = 0; i < (1 << n); i++) scanf("%d", &a[i]);
for(R int i = 0; i < (1 << n); i++) scanf("%d", &b[i]);
R int t; scanf("%d", &t);
if(t == 1)
{
for(R int i = 0; i < n; ++i)
for(R int j = 0; j < (1 << n); ++j) if(!(j >> i & 1))
{
(a[j | 1 << i] += a[j]) %= Mod;
(b[j | 1 << i] += b[j]) %= Mod;
}
for(R int i = 0; i < (1 << n); ++i) c[i] = 1ll * a[i] * b[i] % Mod;
for(R int i = 0; i < n; ++i)
for(R int j = 0; j < (1 << n); ++j) if(!(j >> i & 1))
(c[j | 1 << i] -= c[j] - Mod) %= Mod;
/* 集合交卷积
for(R int i = 0; i < n; ++i)
for(R int j = 0; j < (1 << n); ++j) if(j >> i & 1)
f[j ^ 1 << i] += f[j];
for(R int i = 0; i < n; ++i)
for(R int j = 0; j < (1 << n); ++j) if(j >> i & 1)
g[j ^ 1 << i] += g[j];
for(R int i = 0; i < (1 << n); ++i) h[i] = f[i] * g[i];
for(R int i = 0; i < n; ++i)
for(R int j = 0; j < (1 << n); ++j) if(j >> i & 1)
h[j ^ 1 << i] -= h[j];
*/
}
else
{
for(R int i = 0; i < n; ++i)
for(R int j = 0; j < (1 << n); ++j) if(!(j >> i & 1))
{
R int l = a[j], r = a[j | 1 << i];
a[j] = (l + r) % Mod; a[j | 1 << i] = (Mod + l - r) % Mod;
l = b[j], r = b[j | 1 << i];
b[j] = (l + r) % Mod; b[j | 1 << i] = (Mod + l - r) % Mod;
}
for(R int i = 0; i < (1 << n); i++) c[i] = 1ll * a[i] * b[i] % Mod;
for(R int i = 0; i < n; ++i)
for(R int j = 0; j < (1 << n); ++j) if(!(j >> i & 1))
{
R int l = c[j], r = c[j | 1 << i];
c[j] = (l + r) % Mod; c[j | 1 << i] = (Mod + l - r) % Mod;
}
for(R int j = 0, t = Pow(2, Mod - 1 - n); j < (1 << n); ++j) c[j] = 1ll * c[j] * t % Mod;
}
R int Ans = 0;
for(R int i = 0; i < (1 << n); i++) Ans = (1ll * Ans * p + c[i]) % Mod;
printf("%d", Ans);
return 0;
}
FWT模板,包含or和xor,递归版
#include <cstdio>
#include <cstring>
#define R register
/* Prime modulus for all arithmetic, and the base of the output hash. */
const int Mod = 998244353;
const int p = 2341;
/* File-scope (hence zero-initialised) work arrays: the two inputs and their convolution. */
int a[1 << 20];
int b[1 << 20];
int c[1 << 20];
void FWTor(R int *A, R int l, R int r)
{
if(r - l <= 1) return ;
R int mid = l + r >> 1;
FWTor(A, l, mid), FWTor(A, mid, r);
for(R int i = l, j = mid; j < r; i++, j++) (A[j] += A[i]) %= Mod;
}
void unFWTor(R int *A, R int l, R int r)
{
if(r - l <= 1) return ;
R int mid = l + r >> 1;
unFWTor(A, l, mid), unFWTor(A, mid, r);
for(R int i = l, j = mid; j < r; i++, j++) (A[j] -= A[i] - Mod) %= Mod;
}
void FWTxor(R int *A, R int l, R int r)
{
if(r - l <= 1) return ;
R int mid = l + r >> 1;
FWTxor(A, l, mid), FWTxor(A, mid, r);
for(R int i = l, j = mid; j < r; i++, j++)
{
R int f = A[i], g = A[j];
A[i] = (f + g) % Mod, A[j] = (f - g) % Mod;
}
}
void unFWTxor(R int *A, R int l, R int r)
{
if(r - l <= 1) return ;
R int mid = l + r >> 1;
unFWTxor(A, l, mid), unFWTxor(A, mid, r);
for(R int i = l, j = mid; j < r; i++, j++)
{
R int f = A[i], g = A[j];
A[i] = 1ll * (f + g) * 499122177 % Mod, A[j] = 1ll * (f - g) * 499122177 % Mod;
}
}
/*
 * Reads n, then 2^n values of a, 2^n values of b, then t.
 * t == 1: c = OR convolution of a and b; any other t: XOR convolution
 * (all modulo Mod).  Prints c space-separated followed by a newline, then
 * the hash  sum c[i] * p^(2^n - 1 - i)  mod Mod (no trailing newline).
 * Returns 1 without output on malformed input or when 2^n exceeds the arrays.
 */
int main()
{
    int n;
    if (scanf("%d", &n) != 1)
        return 1;
    /* a, b and c hold sizeof a / sizeof a[0] entries each; reject any n whose
       2^n would overflow them (n > 30 is rejected first so the shift is defined). */
    if (n < 0 || n > 30 || (1u << n) > sizeof a / sizeof a[0])
        return 1;
    const int size = 1 << n;
    for (int i = 0; i < size; i++) {
        if (scanf("%d", &a[i]) != 1)
            return 1;
        a[i] = (a[i] % Mod + Mod) % Mod;  /* the transforms expect values in [0, Mod) */
    }
    for (int i = 0; i < size; i++) {
        if (scanf("%d", &b[i]) != 1)
            return 1;
        b[i] = (b[i] % Mod + Mod) % Mod;
    }
    int t;
    if (scanf("%d", &t) != 1)
        return 1;

    if (t == 1) {
        FWTor(a, 0, size);
        FWTor(b, 0, size);
        for (int i = 0; i < size; i++)
            c[i] = 1ll * a[i] * b[i] % Mod;
        unFWTor(c, 0, size);
    } else {
        FWTxor(a, 0, size);
        FWTxor(b, 0, size);
        for (int i = 0; i < size; i++)
            c[i] = 1ll * a[i] * b[i] % Mod;
        unFWTxor(c, 0, size);
    }

    /* NOTE(review): the iterative version of this program prints only the hash;
       this full dump of c looks like leftover debug output -- confirm the
       expected output format before submitting. */
    for (int i = 0; i < size; i++)
        printf("%d ", c[i]);
    puts("");

    int Ans = 0;
    for (int i = 0; i < size; i++)
        Ans = (1ll * Ans * p + c[i]) % Mod;
    printf("%d", Ans);
    return 0;
}