Description
Input
第一行一整数m,之后输入序列A和B(m<=19,0<=A[i],B[i]<998244353)
Output
输出答案
Sample Input
2
1 2 3 4
5 6 7 8
Sample Output
568535691
Solution
Code
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int mod=998244353,inv2=499122177,maxn=1<<20;
int m,n,A[maxn],B[maxn],C[maxn],bit[maxn],a[maxn][21],b[maxn][21],c[maxn][21];
void read(int &x)
{
x=0;
char p=getchar();
while(!(p<='9'&&p>='0'))p=getchar();
while(p<='9'&&p>='0')x*=10,x+=p-48,p=getchar();
}
int mod_pow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1)ans=(ll)ans*a%mod;
a=(ll)a*a%mod;
b>>=1;
}
return ans;
}
void FWT(int a[maxn][21],int n,int sta)
{
for(int d=1;d<n;d<<=1)
for(int i=0;i<n;i+=(d<<1))
for(int j=0;j<d;j++)
for(int k=0;k<=m;k++)
{
int x=a[i+j][k],y=a[i+j+d][k];
a[i+j][k]=(x+y)%mod,a[i+j+d][k]=(x-y+mod)%mod;
//xor:a[i+j]=x+y,a[i+j+d]=(x-y+mod)%mod;
//and:a[i+j]=x+y;
//or:a[i+j+d]=x+y;
}
if(sta==1)
{
int inv=mod_pow(inv2,m);
for(int i=0;i<n;i++)
for(int j=0;j<=m;j++)
a[i][j]=(ll)a[i][j]*inv%mod;
}
}
ull temp[21];
int main()
{
read(m);
n=1<<m;
for(int i=0;i<n;i++)read(A[i]);
for(int i=0;i<n;i++)read(B[i]);
for(int i=0;i<n;i++)bit[i]=bit[i>>1]+(i&1);
for(int i=0;i<n;i++)A[i]=(ll)A[i]*(1<<bit[i])%mod;
for(int i=0;i<n;i++)a[i][bit[i]]=A[i],b[i][bit[i]]=B[i];
FWT(a,n,0),FWT(b,n,0);
for(int i=0;i<n;i++)
{
memset(temp,0,sizeof(temp));
for(int j=0;j<=m;j++)
for(int k=0;k<=j;k++)
{
//c[i][j-k]=(c[i][j-k]+(ll)a[i][k]*b[i][j])%mod;
temp[j-k]+=(ll)a[i][k]*b[i][j];
if(temp[j-k]>=(1ll<<63))temp[j-k]%=mod;
}
for(int j=0;j<=m;j++)c[i][j]=temp[j]%mod;
}
FWT(c,n,1);
for(int i=0;i<n;i++)C[i]=c[i][bit[i]];
int ans=0,p=1;
for(int i=0;i<n;i++)
{
ans=(ans+(ll)C[i]*p)%mod;
p=1526ll*p%mod;
}
printf("%d\n",ans);
return 0;
}