说在前面
并没有什么想说的,但是要保持格式=w=
题目
题目大意
给出N个数字,N不超过1e6,数字大小不超过1e6
从中选出两个不相交的集合,使得这两个集合异或和相等,不要求把所有数全部选到,问方案数
输入输出格式
输入格式:
第一行一个整数N,含义如题
接下来一行N个整数,描述这一堆数字
输出格式:
输出方案数在模998244353意义下的值
解法
首先可以想到一个最暴力的dp,就是定义dp[i][j][k]表示,前i个数里面,一个集合异或和为j,另一个集合异或和为k的方案数。转移是显然的,然而复杂度高的令人窒息…
然后可以发现,最后两维并没有什么卵用。因为,如果能选出一个子集,这个子集的异或和为0,那么把这个子集分成两部分,这两部分的异或和一定相等。于是可以把状态减小一维:定义dp[i][j]表示,前i个数里面,选出两个不相交集合异或和为j的方案数。转移是显然的: dp[i][j]=dp[i−1][j]+dp[i−1][j⊕a[i]]∗2 d p [ i ] [ j ] = d p [ i − 1 ] [ j ] + d p [ i − 1 ] [ j ⊕ a [ i ] ] ∗ 2 (乘2是因为有两个集合),最后答案就是dp[N][0]。然后这个复杂度仍然过不去
然后发现,这个转移是一个异或卷积的形式,与之相卷的数组假设是b[],那么b[0] = 1,b[ a[i] ] = 2。那么每一层转移,相当于卷上这么一个数组b(每一层转移对应一个b,这些b不一定相同)。
观察一下这个b的特点,发现b的正变换要么是1,要么是-3(因为0位置会对所有位置贡献1,而a[i]位置会对所有位置贡献2或者-2。不理解的,去复习异或卷积的正变换)。
那么如果能快速的知道,这一位在所有转移中乘上了多少个-1,多少个3,就可以直接快速幂,然后逆变换得出答案了
假设我们把所有正变换之后的数组加起来,某一位上的数字是S,那么可以列出方程: 3x−y=S 3 x − y = S 且 x+y=N x + y = N ,解出来就好了。那么,所有正变换的数组加起来,实际上就是所有数组加起来的正变换(因为FWT可加)
所以这道题就被解决了
下面是自带大常数的代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
const int P = 998244353 , inv2 = 499122177 ;
int N , remN , cnt[2000005] , maxv , mi3[1000005] ;
void FWT( int *a , int lf , int rg , bool inver ){
int siz = ( rg - lf + 1 ) / 2 , t0 , t1 ;
if( !siz ) return ;
FWT( a , lf , lf + siz - 1 , inver ) ;
FWT( a , rg - siz + 1 , rg , inver ) ;
for( int i = lf + siz - 1 ; i >= lf ; i -- ){
t0 = a[i] , t1 = a[i+siz] ;
if( inver ){
a[i] = 1LL * ( t0 + t1 ) * inv2 %P ;
a[i+siz] = 1LL * ( t0 - t1 ) * inv2 %P ;
} else a[i] = ( t0 + t1 )%P , a[i+siz] = ( t0 - t1 )%P ;
}
}
void solve(){
for( N = 1 ; N <= maxv ; N <<= 1 ) ;
FWT( cnt , 0 , N - 1 , 0 ) ;
mi3[0] = 1 ;
for( int i = 1 ; i <= remN ; i ++ ) mi3[i] = 3LL * mi3[i-1] %P ;
for( int i = 0 , x ; i < N ; i ++ ){
// 3x - y = cnt[i] && x + y = remN
// 4x - remN = cnt[i] ===> x = ( cnt[i] + remN ) / 4 ;
x = ( cnt[i] + remN ) / 4 ;
cnt[i] = (remN-x)&1 ? -mi3[x] : mi3[x] ;
} FWT( cnt , 0 , N - 1 , 1 ) ;
printf( "%d" , ( cnt[0] + P - 1 )%P ) ;
}
int main(){
scanf( "%d" , &N ) , remN = N ;
for( int i = 0 , x ; i < N ; i ++ ){
scanf( "%d" , &x ) ;
cnt[x] += 2 ; maxv = max( maxv , x ) ;
} cnt[0] += N ; solve() ;
}