Kanade’s convolution,HDU - 6057
https://vjudge.net/problem/HDU-6057/origin
Give you two arrays A[0…2m−1] and B[0…2m−1].
Please calculate array C[0…2m−1]:
C[k]=∑i and j=kA[i xor j]∗B[i or j]
You just need to print ∑2m−1i=0C[i]∗1526i mod 998244353
m<=19
0≤A[i],B[i]<998244353
思路: 令x = i ^ j,y = i | j
i | 0 | 0 | 1 | 1 |
---|---|---|---|---|
j | 0 | 1 | 0 | 1 |
x | 0 | 1 | 1 | 0 |
y | 0 | 1 | 1 | 1 |
i&j | 0 | 0 | 0 | 1 |
见上表可知i & j = y-x
故原式
可以改写为
C
[
k
]
=
∑
y
−
x
=
k
A
[
x
]
∗
B
[
y
]
C[k] = \sum_{y-x=k} A[x] *B[y]
C[k]=∑y−x=kA[x]∗B[y]
但实际上由枚举i,j变为枚举x,y时多了一些限制,见上表可以发现当(i,j) = (0,1)或(1,0)时(x,y)都等于(1,1),即当x的某一位为1时,实际上有两组(i,j)可以得到,设x中1的位数为bit(x),同样(x,y)中没有(1,0)的情况,即x & y = x,且y - x = x^y
故实际上是求
C
[
k
]
=
∑
x
⨁
y
=
k
[
x
&
y
=
=
x
]
∗
A
[
x
]
∗
B
[
y
]
∗
2
b
i
t
(
x
)
C[k] = \sum_{x \bigoplus y = k}[x \& y == x]*A[x]*B[y] *2^{bit(x)}
C[k]=∑x⨁y=k[x&y==x]∗A[x]∗B[y]∗2bit(x)
由上表可知[x & y == x]这一限制条件实际上与bit(y) - bit(x) = bit(k)等价,故改为求
C
[
k
]
=
∑
x
⨁
y
=
k
[
b
i
t
(
y
)
−
b
i
t
(
x
)
=
=
b
i
t
(
k
)
]
∗
A
[
x
]
∗
B
[
y
]
∗
2
b
i
t
(
x
)
C[k] = \sum_{x \bigoplus y = k}[bit(y) - bit(x) == bit(k)]*A[x]*B[y] *2^{bit(x)}
C[k]=∑x⨁y=k[bit(y)−bit(x)==bit(k)]∗A[x]∗B[y]∗2bit(x)
可以通过枚举bit(y),bit(x),算出对应的值再累加起来即可
#include<bits/stdc++.h>
#define ll long long
#define MOD 998244353
#define T 1526
using namespace std;
inline int qpow(int a,int b,int mod)
{
int ans =1;
while(b)
{
if(b & 1) ans = 1ll * ans * a % mod;
a = 1ll * a * a % mod;
b >>= 1;
}
return ans;
}
int len,inv2;
void FWT_xor(int *A,int on)//非模意义下不取模即可,但要注意开long long(即使最终结果在int范围内)
{
for(int i=1;i<len;i<<=1) for(int p=i<<1,j=0;j<len;j+=p)
for(int k=0;k<i;++k)
{
int x=A[j+k],y=A[i+j+k];
A[j+k]=(x+y)%MOD,A[i+j+k]=(x+MOD-y)%MOD;
if(on==-1) A[j+k]=1ll*A[j+k]*inv2%MOD,A[i+j+k]=1ll*A[i+j+k]*inv2%MOD;//这里是在mol意义下的,否则把inv2改为/2
}
}
const int MAXN = (1<<19)+5;
int a[MAXN],b[MAXN],ta[20][MAXN],tb[20][MAXN],bit[MAXN],ans[MAXN],tmp[20][MAXN];
int m;
int main()
{
inv2 = qpow(2,MOD-2,MOD);
for(int i = 0;i < MAXN-5;++i)
{
int cnt = 0;
for(int j = 0;j < 19;++j)
if(i & (1<<j))
++cnt;
bit[i] = cnt;
}
while(~scanf("%d",&m))
{
memset(ans,0,sizeof(int) * (1<<m));
for(int i = 0;i <= m;++i)
memset(tmp[i],0,sizeof(int) * (1<<m));
len = (1<<m);
for(int i = 0;i < len;++i)
scanf("%d",&a[i]);
for(int i = 0;i < len;++i)
scanf("%d",&b[i]);
for(int j = 0;j <= m;++j)
{
for(int i = 0;i < len;++i)
{
if(bit[i] == j)
ta[j][i] = 1ll*a[i]*(1<<j)%MOD;
else
ta[j][i] = 0;
if(bit[i] == j)
tb[j][i] = b[i];
else
tb[j][i] = 0;
}
FWT_xor(ta[j],1),FWT_xor(tb[j],1);
}
for(int x = 0;x <= m;++x)
for(int y = x;y <= m;++y)
{
for(int i = 0;i < len;++i)
tmp[y-x][i] = (tmp[y-x][i]+1ll * ta[x][i] * tb[y][i]%MOD)%MOD;
}
for(int i = 0;i <= m;++i)
FWT_xor(tmp[i],-1);
ll res = 0;
ll tt = 1;
for(int i = 0;i < len;++i)
res = (res + 1ll * tmp[bit[i]][i] *tt%MOD)%MOD,tt = tt * T % MOD;
printf("%lld\n",res);
}
return 0;
}