这是若干个 $2x^{a_i}+1$ 的异或卷积(XOR 卷积)。
对单个 $2x^{a_i}+1$ 做 FWT,第 $k$ 项为 $1+2\cdot(-1)^{\operatorname{popcount}(k\,\&\,a_i)}$,所以每一项只可能是 $-1$ 或 $3$。
那么卷积结果的 FWT 每一项就是若干个 $-1$ 与若干个 $3$ 的乘积;但对每个因子分别做 FWT 再逐项相乘太慢,这个不好直接求。
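把所有 $2x^{a_i}+1$ 直接加在一起再做一次 FWT(FWT 是线性的),得到的第 $k$ 项是 $\hat{A}_k = 3(n-c_k) - c_k = 3n - 4c_k$,其中 $c_k$ 是第 $k$ 项上取值为 $-1$ 的因子个数。
因为每个因子只有这两种取值,可以直接解方程,得到 $c_k = (3n - \hat{A}_k)/4$。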
于是卷积的 FWT 第 $k$ 项就是 $(-1)^{c_k}\cdot 3^{\,n-c_k}$(用快速幂或预处理 $3$ 的幂求出),逐项代入后再做逆 FWT 即可。
#include<cstdio>
#include<cstdlib>
#include<algorithm>
using namespace std;
// 64-bit integer used for products of two residues mod P, which would overflow int.
typedef long long ll;
// Return the next byte of stdin through a 100000-byte static buffer that is refilled
// with fread whenever it runs dry. Once fread yields nothing, returns (char)EOF.
inline char nc(){
	static char buf[100000];
	static char *head=buf,*tail=buf;
	if (head==tail){
		head=buf;
		tail=buf+fread(buf,1,100000,stdin);
		if (head==tail) return EOF;
	}
	return *head++;
}
// Read one decimal integer from the buffered input (nc) into x.
// Non-digit characters before the number are skipped; a '-' among them makes the
// result negative. If input ends before any digit appears, x is set to 0 instead of
// looping forever on end of input.
// NOTE: nc() reports end of input as (char)EOF, so a literal 0xFF byte looks the same.
inline void read(int &x){
	const char end_of_input=static_cast<char>(EOF);
	char c=nc();
	int sign=1;
	for (;!(c>='0' && c<='9');c=nc()){
		if (c==end_of_input){
			x=0;
			return;
		}
		if (c=='-') sign=-1;
	}
	for (x=0;c>='0' && c<='9';x=x*10+c-'0',c=nc());
	x*=sign;
}
const int P=998244353;      // modulus for all arithmetic
const int INV2=(P+1)>>1;    // inverse of 2 mod P: 2*((P+1)/2) = P+1 == 1 (mod P) because P is odd
const int N=1048576;        // transform length 2^20; every input value is used as an index, so must be < N
//const int N=8;            // smaller size, presumably for local debugging (NOTE(review): confirm)
int a[N];                   // sum over inputs of (1 + 2x^{a_i}); transformed in place by main
inline void FWT(int *a,int n,int r){
for (int i=1;i<n;i<<=1)
for (int j=0;j<n;j+=(i<<1))
for (int k=0;k<i;k++){
int x=a[j+k],y=a[j+k+i];
if (r) a[j+k]=(x+y)%P,a[j+k+i]=(x+P-y)%P;
else a[j+k]=(ll)(x+y)*INV2%P,a[j+k+i]=(ll)(x+P-y)*INV2%P;
}
}
// pw[k] = 3^k mod P; main fills entries 0..n.
int pw[N];
// Square-and-multiply modular exponentiation: returns a^b mod P.
// Expects 0 <= a < P (so a*a fits in long long) and b >= 0.
// main in this program uses the precomputed pw table instead of calling this.
inline ll Pow(ll a,int b){
	ll acc=1;
	while (b){
		if (b&1) acc=acc*a%P;
		a=a*a%P;
		b>>=1;
	}
	return acc;
}
// Program 1: XOR-convolve the n polynomials (1 + 2x^{a_i}) via one forward transform of
// their sum, then print (coefficient of x^0) - 1 modulo P.
int main(){
	freopen("t.in","r",stdin);
	freopen("t.out","w",stdout);
	int n;
	read(n);
	// Accumulate A = sum_i (1 + 2x^{a_i}) and tabulate pw[k] = 3^k mod P.
	pw[0]=1;
	for (int i=1;i<=n;++i){
		int v;
		read(v);
		a[0]+=1;
		a[v]+=2;
		pw[i]=(ll)pw[i-1]*3%P;
	}
	FWT(a,N,1);
	// Each factor transforms to 3 or -1 at every position, so the summed transform is
	// 3n - 4c (mod P), with c = number of factors equal to -1 there. Recover c and
	// replace the entry by the product (-1)^c * 3^(n-c).
	for (int w=0;w<N;++w){
		const int neg=(ll)(3*n+P-a[w])*INV2%P*INV2%P;
		const int mag=pw[n-neg];
		a[w]=(neg%2==0)?mag:(P-mag)%P;
	}
	FWT(a,N,0);
	printf("%d\n",(a[0]+P-1)%P);
	return 0;
}
还有阿爷的分层 FWT 做法(代码如下),具体细节我忘了;还是那句话,退役了。
#include<cstdio>
#include<cstdlib>
#include<algorithm>
using namespace std;
// 64-bit integer used for products of two residues mod P, which would overflow int.
typedef long long ll;
// Return the next byte of stdin through a 100000-byte static buffer that is refilled
// with fread whenever it runs dry. Once fread yields nothing, returns (char)EOF.
inline char nc(){
	static char buf[100000];
	static char *head=buf,*tail=buf;
	if (head==tail){
		head=buf;
		tail=buf+fread(buf,1,100000,stdin);
		if (head==tail) return EOF;
	}
	return *head++;
}
// Read one decimal integer from the buffered input (nc) into x.
// Non-digit characters before the number are skipped; a '-' among them makes the
// result negative. If input ends before any digit appears, x is set to 0 instead of
// looping forever on end of input.
// NOTE: nc() reports end of input as (char)EOF, so a literal 0xFF byte looks the same.
inline void read(int &x){
	const char end_of_input=static_cast<char>(EOF);
	char c=nc();
	int sign=1;
	for (;!(c>='0' && c<='9');c=nc()){
		if (c==end_of_input){
			x=0;
			return;
		}
		if (c=='-') sign=-1;
	}
	for (x=0;c>='0' && c<='9';x=x*10+c-'0',c=nc());
	x*=sign;
}
const int P=998244353;      // modulus for all arithmetic
const int INV2=(P+1)>>1;    // inverse of 2 mod P: 2*((P+1)/2) = P+1 == 1 (mod P) because P is odd
const int N=1048576;        // transform length 2^20; every input value is used as an index, so must be < N
//const int N=8;            // smaller size, presumably for local debugging (NOTE(review): confirm)
// In-place INVERSE XOR (Walsh-Hadamard) transform of a[0..n) modulo P. There is no
// forward mode despite the name. Every butterfly maps (u, v) -> ((u+v)/2, (u-v)/2),
// so the overall scaling is 1/n. Expects n a power of two and entries in [0, P).
inline void FWT(int *a,int n){
	for (int half=1;half<n;half<<=1){
		const int len=half<<1;
		for (int base=0;base<n;base+=len)
			for (int t=base;t<base+half;++t){
				const int u=a[t],v=a[t+half];
				a[t]=(ll)(u+v)*INV2%P;
				a[t+half]=(ll)(u+P-v)*INV2%P;
			}
	}
}
int c[N];       // c[v] = how many input values equal v
int n;          // number of input values
int f[2][N];    // layered state: f[0] = even-parity part, f[1] = odd-parity part (see main)
int g[2][N];    // scratch buffer holding the next level of f during the merge
// x += y, then subtract P until x < P. Callers in this file pass x in [0, P)
// and y in [0, P], so at most one subtraction happens there.
inline void Add(int &x,int y){
	for (x+=y;x>=P;x-=P){
	}
}
int main(){
int x;
freopen("t.in","r",stdin);
freopen("t.out","w",stdout);
read(n);
for (int i=1;i<=n;i++) read(x),c[x]++;
for (int i=0;i<N;i++){
ll f[2]={1,0},g[2];
for (int j=1;j<=c[i];j++){
g[0]=(f[0]+f[1]*2LL)%P,g[1]=(f[1]+f[0]*2LL)%P;
f[0]=g[0],f[1]=g[1];
}
::f[0][i]=f[0],::f[1][i]=f[1];
}
for (int i=1;i<N;i<<=1){
for (int t=0;t<N;t++) g[0][t]=g[1][t]=0;
for (int j=0;j<N;j+=(i<<1)){
for (int x=0;x<2;x++)
for (int y=0;y<2;y++){
int *g=::g[x^y],*f1=f[x],*f2=f[y];
for (int k=0;k<i;k++){
int tem=(ll)f1[j+k]*f2[j+i+k]%P;
Add(g[j+k],tem);
if (y==0) Add(g[j+i+k],tem);
else Add(g[j+i+k],P-tem);
}
}
}
for (int t=0;t<N;t++) f[0][t]=g[0][t],f[1][t]=g[1][t];
}
FWT(f[0],N);
FWT(f[1],N);
printf("%d\n",(f[0][0]+f[1][0]-1)%P);
return 0;
}