题目链接:
HDU 6057
题意:
给你
A[0..2 m −1]
和
B[0..2 m −1]
。
先让你求
C[0..2 m −1]
。
C[0..2 m −1]
的运算方式为:
C[k]=∑ i and j=k A[i xor j]∗B[i or j]
。
看到卷积里面有位运算,就应该先想到
FWT
。
最后让你输出
∑ 2 m −1 i=0 C[i]∗1526 i mod 998244353
。
其中,
m<=19,0≤A[i],B[i]<998244353
。
题解:
官方题解。
FWT
。快速沃尔什变换。
看到卷积里面有位运算,就应该先想到
FWT
。
对于任意两个数
a,b
,我们有:
a and b=(a or b)−(a xor b)。
我们枚举
x=a or b
,
y=a xor b
,显然要求的条件是
x and y=y
。
然后我们可以计算满足以上条件的 (a,b) 有 2 bit(y) 个。
于是可以将 C[k]=∑ i and j=k A[i xor j]∗B[i or j] 重新写成:
C[k]=∑ x ∑ y [x and y=y]∗[x−y=k]∗A[y]∗B[x]∗2 bit(y)
C[k]=∑ x ∑ y [x and y=y]∗[x xor y=k]∗A[y]∗B[x]∗2 bit(y)
C[k]=∑ x xor y=k [x and y=y]∗B[x]∗A[y]∗2 bit(y)
C[k]=∑ x xor y=k [bit(x)−bit(y)=bit(k)]∗B[x]∗A[y]∗2 bit(y)
用 FWT 计算即可。
时间复杂度: O(2 m ∗m 2 ) .
AC代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=998244353;
const int N=(1<<20);
int q_mod(int a, int b)
{
int res = 1;
while(b)
{
if(b&1)res=1LL*res*a%mod;
b>>=1;
a=1LL*a*a%mod;
}
return res;
}
int A[N],B[N],C[N];
int n,m;
int bit[N];
int aa[N][21];
int bb[N][21];
int cc[N][21];
ll ds[21];
void FWT(int a[N][21],int flag)
{
for(int i=1;i<n;i<<=1)
{
for(int j=0;j<n;j+=(i<<1))
{
for(int k=0;k<=i-1;k++)
{
for(int l=0;l<=m;l++)
{
int x=a[j+k][l],y=a[i+j+k][l];
a[j + k][l]=(x + y)%mod;
a[i + j + k][l]=(x - y + mod) % mod;
}
}
}
}
if(flag==1)
{
int p=q_mod((mod+1)>>1,m);
for(int i=0;i<=n-1;i++)
{
for(int j=0;j<=m;j++)
{
a[i][j]=1LL*a[i][j]*p%mod;
}
}
}
}
/*
2
1 2 3 4
5 6 7 8
568535691
*/
int main()
{
scanf("%d",&m);
n=1<<m;
for(int i=0;i<=n-1;i++)
{
scanf("%d",&A[i]);
}
for(int i=0;i<=n-1;i++){
scanf("%d",&B[i]);
}
for(int i=0;i<=n-1;i++)
{
bit[i]=bit[i>>1]+(i&1);
}
for(int i=0;i<=n-1;i++)
{
A[i]=1LL*A[i]*(1<<bit[i])%mod;
}
for(int i=0;i<=n-1;i++)
{
aa[i][bit[i]]=A[i];
bb[i][bit[i]]=B[i];
}
FWT(aa,0);
FWT(bb,0);
for(int i=0;i<=n-1;i++)
{
memset(ds,0,sizeof(ds));
for(int j=0;j<=m;j++)
{
for(int k=0;k<=j;k++)
{
ds[j-k]=(ds[j-k]+1LL*bb[i][j]*aa[i][k]);
if(ds[j-k]>=1LL<<63)
{
ds[j-k]%=mod;
}
}
}
for(int j=0;j<=m;j++) cc[i][j]=ds[j]%mod;
}
FWT(cc,1);
for(int i=0;i<=n-1;i++){
C[i]=cc[i][bit[i]];
}
int ans=0;
int base=1;
for(int i=0;i<=n-1;i++){
ans=(ans+1LL*C[i]*base)%mod;
base=base*1LL*1526LL%mod;
}
printf("%d\n",ans);
return 0;
}