说在前面
人老了脑子不好使了hhhhh
之前还想得很明白,去食堂吃了饭回来,连自己之前写的式子都看不懂了。。。
题目
题目大意
给出一个两个长为 2M 2 M 的数组 A[] A [ ] 和 B[] B [ ] ,下标从0开始(M不超过19)
求出数组 C[] C [ ] ,满足 C[k]=∑i and j=kA[i xor j]∗B[i or j] C [ k ] = ∑ i a n d j = k A [ i x o r j ] ∗ B [ i o r j ]
并输出 ∑2m−1i=0C[i]∗1526i mod 998244353 ∑ i = 0 2 m − 1 C [ i ] ∗ 1526 i m o d 998244353 的值
输入输出格式
输入格式:
第一行一个整数M,含义如题
接下来一行
2M
2
M
个整数,表示
A[0..2m−1]
A
[
0..2
m
−
1
]
接下来一行
2M
2
M
个整数,表示
B[0..2m−1]
B
[
0..2
m
−
1
]
输出格式:
输出一行一个整数,表示答案
解法
首先观察一下这个式子然后就被吓到了,发现数组的下标里面有运算符号, 不好处理,于是我们用其它的东西来替换之
令:
i xor j=x
i
x
o
r
j
=
x
,
i or j=y
i
o
r
j
=
y
,那么原式可以写成:
C[k]=∑y−x=kA[x]∗B[y]
C
[
k
]
=
∑
y
−
x
=
k
A
[
x
]
∗
B
[
y
]
,并且同时要满足
x and y=x
x
a
n
d
y
=
x
因为
y
y
包含了,所以减法相当于是异或,并且
x
x
与的二进制位之差,必须是
k
k
的二进制位。
所以上式可以改写成:
然后发现这就是一个位运算卷积的形式了,如果去掉
[bit(y)−bit(x)=bit(k)]
[
b
i
t
(
y
)
−
b
i
t
(
x
)
=
b
i
t
(
k
)
]
,这就是一个简单的异或卷积
于是我们按照一个数字里二进制位上1的个数,把A,B分别拆分成M个序列。对这M个序列分别做异或正变换。为了满足
bit(y)−bit(x)=bit(k)
b
i
t
(
y
)
−
b
i
t
(
x
)
=
b
i
t
(
k
)
,将
A[x][]
A
[
x
]
[
]
和
B[y][]
B
[
y
]
[
]
所乘,累加到
C[k][]
C
[
k
]
[
]
里面,然后再做一遍异或逆变换,得到的
C[bit(k)][k]
C
[
b
i
t
(
k
)
]
[
k
]
就是所求的
C[k]
C
[
k
]
关于正确性,其实可以这么想:不要按照题目中的来,假设现在就是得到了两个数组A、B。其中,A[]里只有下标二进制1的个数为p的才有值;B[]里只有下标二进制1的个数为q的才有值。对A和B做一次异或卷积,然后得到了数组C[]。但是这个C里面,只满足了异或限制。即,得到的C里,下标二进制1的个数不是p-q的也有值。只是下标二进制1的个数恰好等于p-q的才合法而已,其它的数都不合法,统统扔掉
于是回过头来看这道题,大概也是这样子的
下面是代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
const int P = 998244353 , inv2 = 499122177 ;
int N , ws , a[(1<<20)+5] , b[(1<<20)+5] , c[(1<<20)+5] , popcnt[(1<<20)+5] ;
int A[20][1<<20] , B[20][1<<20] , C[20][1<<20] ;
void FWT( int *a , int lf , int rg , bool rev ){
if( lf == rg ) return ;
int siz = ( rg - lf + 1 ) >> 1 , t0 , t1 ;
FWT( a , lf , lf + siz - 1 , rev ) ;
FWT( a , rg - siz + 1 , rg , rev ) ;
for( int i = lf + siz - 1 ; i >= lf ; i -- ){
t0 = a[i] , t1 = a[i+siz] ;
if( rev ){
a[i] = 1LL * inv2 * ( t0 + t1 ) %P ;
a[i+siz] = 1LL * inv2 * ( t0 - t1 ) %P ;
} else a[i] = ( t0 + t1 )%P , a[i+siz] = ( t0 - t1 )%P ;
}
}
void solve(){
for( int i = 1 ; i < N ; i ++ )
popcnt[i] = popcnt[i-(i&-i)] + 1 ;
for( int i = 0 ; i < N ; i ++ ){
A[ popcnt[i] ][i] = ( 1LL << popcnt[i] ) * a[i] %P ;
B[ popcnt[i] ][i] = b[i] ;
} for( int i = 0 ; i <= ws ; i ++ )
FWT( A[i] , 0 , N - 1 , 0 ) ,
FWT( B[i] , 0 , N - 1 , 0 ) ;
for( int k = 0 ; k <= ws ; k ++ ){
for( int y = k , x = 0 ; y <= ws ; y ++ , x ++ )
for( int i = 0 ; i < N ; i ++ )
C[k][i] = ( C[k][i] + 1LL * A[x][i] * B[y][i] )%P ;
FWT( C[k] , 0 , N - 1 , 1 ) ;
} for( int i = 0 ; i < N ; i ++ ) c[i] = C[ popcnt[i] ][i] ;
long long muls = 1 , ans = 0 ;
for( int i = 0 ; i < N ; i ++ ){
ans = ( ans + c[i] * muls )%P ;
muls = muls * 1526 %P ;
} printf( "%lld\n" , ( ans + P )%P ) ;
}
int main(){
scanf( "%d" , &ws ) ; N = ( 1 << ws ) ;
for( int i = 0 ; i < N ; i ++ ) scanf( "%d" , &a[i] ) ;
for( int i = 0 ; i < N ; i ++ ) scanf( "%d" , &b[i] ) ;
solve() ;
}