Kanade’s convolution
题解
多简单的一道题呀
我们可以考虑子集求和。毕竟这种又是位运算又是相乘的很容易想到FWT。
由于
(
x
x
o
r
y
)
+
(
x
a
n
d
y
)
=
(
x
o
r
y
)
(x\,\,xor\,\,y)+(x\,\,and\,\,y)=(x\,\,or\,\,y)
(xxory)+(xandy)=(xory),所以我们我们可以将枚举的
i
,
j
i,j
i,j改为
x
,
y
x,y
x,y使得
x
=
i
o
r
j
,
y
=
i
x
o
r
j
x=i\,\,or\,\,j,y=i\,\,xor\,\,j
x=iorj,y=ixorj。
但是由于我们枚举的
x
,
y
x,y
x,y的出现次数是不一样的,所以还要针对的
y
y
y加上一个
2
b
i
t
(
y
)
2^{bit(y)}
2bit(y)的系数。
原式变成了
C
x
−
y
=
∑
x
a
n
d
y
=
y
2
b
i
t
(
y
)
A
y
B
x
C_{x-y}=\sum_{x\,\,and\,\,y=y}2^{bit(y)}A_{y}B_{x}
Cx−y=∑xandy=y2bit(y)AyBx。
上式明显可以再把
[
x
a
n
d
y
=
y
]
[x\,\,and\,\,y=y]
[xandy=y]这个条件去掉,根据二进制位来求解,用
k
k
k来代替
x
−
y
x-y
x−y,因为这一位是需要我们枚举的,可化为
C
k
=
∑
b
i
t
(
x
)
−
b
i
t
(
y
)
=
b
i
t
(
k
)
2
b
i
t
(
y
)
A
y
B
x
C_{k}=\sum_{bit(x)-bit(y)=bit(k)}2^{bit(y)}A_{y}B_{x}
Ck=∑bit(x)−bit(y)=bit(k)2bit(y)AyBx。这样就可以卷积了。
我们可以先将
2
b
i
t
(
y
)
2^{bit(y)}
2bit(y)乘入
A
y
A_{y}
Ay中,因为它只与
y
y
y的大小有关。
我们可以采用子集卷积的方法,将
A
,
B
,
C
A,B,C
A,B,C三个多项式根据它们二进制下含1的个数分类,用
F
(
A
,
i
)
F(A,i)
F(A,i)表示
A
A
A中二进制的含1个数为
i
i
i的部分。
有,
F
(
A
,
i
)
=
∑
j
=
0
i
F
(
B
,
j
)
F
(
A
,
i
−
j
)
F(A,i)=\sum_{j=0}^{i}F(B,j)F(A,i-j)
F(A,i)=∑j=0iF(B,j)F(A,i−j)。
这明显是可以根据FWT处理的,FWT后相乘再逆回来即可。
时间复杂度为
O
(
n
l
o
g
2
n
)
O\left(nlog^2\,n\right)
O(nlog2n)。
PS:应该没有人跟我最开始想的一样,先通过二元方程进行转化后跑多项式乘法,最后再解回来的吧。
源码
#include<cstdio>
#include<cmath>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN (1<<19)+5
#define reg register
typedef long long LL;
const int mo=998244353;
const int inv2=499122177;
template<typename _T>
inline void read(_T &x){
_T f=1;x=0;char s=getchar();
while('0'>s||'9'<s){if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+(s^48);s=getchar();}
x*=f;
}
int add(const int x,const int y){return x+y<mo?x+y:x+y-mo;}
void FWT(int *A,const int lim,const int typ){
for(reg int k=1;k<lim;k<<=1)
for(reg int i=0;i<lim;i+=(k<<1))
for(reg int j=i;j<i+k;++j){
A[j]=add(A[j],A[j+k]);
A[j+k]=add(A[j],add(mo-A[j+k],mo-A[j+k]));
A[j]=typ^1?1ll*A[j]*inv2%mo:A[j];
A[j+k]=typ^1?1ll*A[j+k]*inv2%mo:A[j+k];
}
}
int a[MAXN],b[MAXN],m,lim,bit[MAXN],F[20][MAXN],G[20][MAXN],pw[MAXN],ans[20][MAXN],sum;
signed main(){
read(m);lim=(1<<m);pw[0]=1;
for(int i=1;i<=m;i++)pw[i]=add(pw[i-1],pw[i-1]);
for(int i=1;i<lim;i++)bit[i]=bit[i^(i&-i)]+1;
for(int i=0;i<lim;i++)read(a[i]),F[bit[i]][i]=1ll*pw[bit[i]]*a[i]%mo;
for(int i=0;i<lim;i++)read(b[i]),G[bit[i]][i]=b[i];
for(int i=0;i<=m;i++)FWT(F[i],lim,1),FWT(G[i],lim,1);
for(int i=0;i<=m;i++)
for(int j=0;j<=m-i;j++)
for(int k=0;k<lim;k++)
ans[i][k]=(1ll*ans[i][k]+1ll*F[j][k]*G[i+j][k])%mo;
for(int i=0;i<=m;i++)FWT(ans[i],lim,-1);int now=1;
for(int i=0;i<lim;i++)sum=(1ll*sum+1ll*ans[bit[i]][i]*now)%mo,now=1526ll*now%mo;
printf("%d\n",sum);
return 0;
}