题目大意
给你两个长度为 2 n − 1 2^n-1 2n−1的数组 A A A和 B B B,计算长度为 2 n − 1 2^n-1 2n−1的数组 C C C
C [ k ] = ∑ i a n d j = k A [ i x o r j ] × B [ i o r j ] C[k]=\sum\limits_{i \ and \ j =k}A[i \ xor \ j]\times B[i \ or \ j] C[k]=i and j=k∑A[i xor j]×B[i or j]
你需要输出 ∑ i = 0 2 m − 1 C [ i ] × 152 6 i \sum\limits_{i=0}^{2^m-1}C[i]\times 1526^i i=0∑2m−1C[i]×1526i对 998244353 998244353 998244353取模后的值。
1 ≤ m ≤ 19 , 0 ≤ A [ i ] , B [ i ] < 998244353 1\leq m\leq 19,0\leq A[i],B[i]<998244353 1≤m≤19,0≤A[i],B[i]<998244353
题解
令 x = i ⊕ j , y = i ∣ j x=i\oplus j,y=i|j x=i⊕j,y=i∣j,因为 ( i ∣ j ) − ( i & j ) = i ⊕ j (i|j)-(i\& j)=i\oplus j (i∣j)−(i&j)=i⊕j,所以 i & j = ( i ∣ j ) − ( i ⊕ j ) i\&j=(i|j)-(i\oplus j) i&j=(i∣j)−(i⊕j),则
C k = ∑ y − x = k , x & k = 0 2 d x × A x × B y C_k=\sum\limits_{y-x=k,x\& k=0}2^{d_x}\times A_x\times B_y Ck=y−x=k,x&k=0∑2dx×Ax×By
其中 d i d_i di表示 i i i的二进制位中为 1 1 1的位数。
因为对于一对数对 ( x , y ) (x,y) (x,y), x x x的每一个为 1 1 1的位, i , j i,j i,j都有两种选法( i i i这一位为 1 1 1且 j j j这一位为 0 0 0、 i i i这一位为 0 0 0且 j j j这一位为 1 1 1)使 x = i ⊕ j , y = i ∣ j x=i\oplus j,y=i|j x=i⊕j,y=i∣j。所以有 2 d x 2^{d_x} 2dx对 ( i , j ) (i,j) (i,j)满足上述条件,要乘上 2 d x 2^{d_x} 2dx。
因为 x = i ⊕ j , y = i ∣ j x=i\oplus j,y=i|j x=i⊕j,y=i∣j,所以 x & y = y x\&y=y x&y=y,可将 y − x = k y-x=k y−x=k换成 y ⊕ x = k y\oplus x=k y⊕x=k。 x & k = 0 x\& k=0 x&k=0也可表示为 d x + d k = d y d_x+d_k=d_y dx+dk=dy。再令 A i = A i × 2 d i A_i=A_i\times 2^{d_i} Ai=Ai×2di,则
C k = ∑ y ⊕ x = k , d x + d k = d y A x × B y C_k=\sum\limits_{y\oplus x=k,d_x+d_k=d_y}A_x\times B_y Ck=y⊕x=k,dx+dk=dy∑Ax×By
多项式 V A i V\!A_i VAi表示二进制位为 1 1 1的个数为 i i i的 A A A值,多项式 V B i V\!B_i VBi表示二进制位为 1 1 1的个数为 i i i的 B B B值,二进制位为 1 1 1的个数不等于 i i i的数在 V A i , V B i V\!A_i,V\!B_i VAi,VBi中值为 0 0 0。
对于每一个 i i i,对 V A i V\!A_i VAi和 V B i V\!B_i VBi做一次异或卷积的 F W T FWT FWT。令 V k = ∑ j = 0 i ( V A j ) k × ( V B i − j ) k V_k=\sum\limits_{j=0}^i(V\!A_j)_k\times (V\!B_{i-j})_k Vk=j=0∑i(VAj)k×(VBi−j)k,将 V k V_k Vk逆变换回来。若 k k k中二进制位为 1 1 1的个数为 i i i,则 C k = V k C_k=V_k Ck=Vk。
最后,对于每个 C i C_i Ci都乘上 152 6 i 1526^i 1526i即可。
时间复杂度为 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
code
#include<iostream>
#include<cstdio>
using namespace std;
int n,cnt[1<<20];
long long x,ans=0,now=1,ta[22][1<<20],tb[22][1<<20],c[1<<20],v[1<<20];
const int mod=998244353,ny2=499122177;
void fwt(long long *w,int fl){
for(int s=2;s<=1<<n;s<<=1){
int mid=s>>1;
for(int v=0;v<1<<n;v+=s){
for(int i=0;i<mid;i++){
int t1=w[v+i],t2=w[v+mid+i];
w[v+i]=(t1+t2)%mod;w[v+mid+i]=(t1-t2+mod)%mod;
if(fl==-1){
w[v+i]=1ll*w[v+i]*ny2%mod;
w[v+mid+i]=1ll*w[v+mid+i]*ny2%mod;
}
}
}
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<1<<n;i++) cnt[i]=cnt[i-(i&(-i))]+1;
for(int i=0;i<1<<n;i++){
scanf("%lld",&x);
x=x*(1ll<<cnt[i])%mod;
ta[cnt[i]][i]=x;
}
for(int i=0;i<1<<n;i++){
scanf("%lld",&x);
tb[cnt[i]][i]=x;
}
for(int i=0;i<=n;i++){
fwt(ta[i],1);fwt(tb[i],1);
}
for(int i=0;i<=n;i++){
for(int j=0;i+j<=n;j++){
for(int k=0;k<1<<n;k++){
c[k]=(c[k]+1ll*ta[j][k]*tb[i+j][k]%mod)%mod;
}
}
fwt(c,-1);
for(int j=0;j<1<<n;j++){
if(cnt[j]==i) v[j]=c[j];
c[j]=0;
}
}
for(int i=0;i<1<<n;i++){
ans=(ans+1ll*v[i]*now%mod)%mod;
now=now*1526ll%mod;
}
printf("%lld\n",ans);
return 0;
}