第一次学会多项式的题目。
题意:
n
n
n个数的多重集
S
S
S,设
S
′
⊆
S
,
f
(
A
)
=
[
∣
A
∣
=
t
m
,
t
∈
Z
]
(
X
O
R
a
i
∈
a
a
i
)
k
S' \subseteq S,f(A)=[|A| =tm,t\in Z](XOR_{a_i\in a}ai)^k
S′⊆S,f(A)=[∣A∣=tm,t∈Z](XORai∈aai)k,求
∑
S
′
⊆
S
f
(
S
′
)
\sum_{S'\subseteq S} f(S')
∑S′⊆Sf(S′)。
我们统计出每个异或值出现多少次,记为 f ( x ) , x ∈ [ 0 , 2 20 − 1 ] f(x),x\in[0,2^{20}-1] f(x),x∈[0,220−1]。
用多项式来表示,有 f ( a ) = [ x a y 0 ] ∏ i = 1 n ( 1 + x a i y ) f(a)=[x^ay^0]\prod_{i=1}^n(1+x^{a_i}y) f(a)=[xay0]∏i=1n(1+xaiy),我们对 x x x做异或卷积,对 y y y做模m的卷积。
不妨把y看成常数,满足 y m = 0 ( m o d 998244353 ) y^m=0 (\mod 998244353) ym=0(mod998244353),然后对每个 ( 1 + y x a i ) (1+yx^{a_i}) (1+yxai)进行FWT,然后点乘起来,最后再求一次逆,但是这样是 O ( n 2 ) O(n^2) O(n2),现在思考如何化简。
我们知道多项式 g ( x ) = 1 g(x)=1 g(x)=1的FWT是全1, h ( x ) = x t h(x)=x^t h(x)=xt的FWT每一项要么是1要么是-1,如果我们知道在求完FWT第 i i i位有多少个多项式为1,设有 c i c_i ci个,那么这一项就是 [ y 0 ] ( 1 + y ) c i ( 1 − y ) n − c i [y^0](1+y)^{c_i}(1-y)^{n-c_i} [y0](1+y)ci(1−y)n−ci,这要预处理一下就能求了。
现在考虑怎么求每一项有多少个1,我们直接把所有的数放进要进行FWT的数组,即考虑多项式 ∑ i = 1 n x a i \sum_{i=1}^n x^{a_i} ∑i=1nxai,设有x个1,y个-1,那么FWT以后的答案是 x − y x-y x−y,又知道 x + y = n x+y=n x+y=n,所以可以解出变量。
#include<bits/stdc++.h>
#define rep(i,x,y) for(int i=x;i<=y;i++)
#define dwn(i,x,y) for(int i=x;i>=y;i--)
#define ll long long
using namespace std;
template<typename T>inline void qr(T &x){
x=0;int f=0;char s=getchar();
while(!isdigit(s))f|=s=='-',s=getchar();
while(isdigit(s))x=x*10+s-48,s=getchar();
x=f?-x:x;
}
int cc=0,buf[31];
template<typename T>inline void qw(T x){
if(x<0)putchar('-'),x=-x;
do{buf[++cc]=int(x%10);x/=10;}while(x);
while(cc)putchar(buf[cc--]+'0');
}
const int N=2e5+10,mod=998244353;
int power(int a,int b){
int ret=1;
while(b){
if(b&1)ret=1ll*ret*a%mod;
a=1ll*a*a%mod;b>>=1;
}
return ret;
}
int n,m,k;
int a[N];ll A[(1<<20)+10];
int s1[N][110],s2[N][110];
int f[N];
void XOR(ll *f,int type=1){
int n=1<<20;
for(int o=2,k=1;o<=n;o<<=1,k<<=1)
for(int i=0;i<n;i+=o)
for(int j=0;j<k;j++){
ll x=f[i+j];
ll y=f[i+j+k];
if(type==1){
f[i+j]=x+y;
f[i+j+k]=x-y;
}
else{
f[i+j]=(x+y)%mod;
f[i+j+k]=((x-y)%mod+mod)%mod;
f[i+j]=1ll*f[i+j]*type%mod;
f[i+j+k]=1ll*f[i+j+k]*type%mod;
}
}
}
void solve(){
qr(n),qr(m),qr(k);
rep(i,1,n){
qr(a[i]);
A[a[i]]++;
}
s1[0][0]=1;
s2[0][0]=1;
rep(i,1,n){
rep(j,0,m-1){
int k=j?j-1:m-1;
s1[i][j]=(s1[i-1][j]+s1[i-1][k])%mod;
s2[i][j]=(s2[i-1][j]-s2[i-1][k]+mod)%mod;
}
}
rep(i,0,n){
rep(j,0,m-1){
int t=j?m-j:0;
f[i]=(f[i]+1ll*s1[i][j]*s2[n-i][t]%mod)%mod;
}
}
XOR(A);
rep(i,0,(1<<20)-1){
// cout<<(A[i]+n)/2<<endl;
A[i]=f[(A[i]+n)/2];
}
XOR(A,power(2,mod-2));
int ans=0;
rep(i,0,(1<<20)-1)
if(A[i]){
// cout<<A[i]<<" ";
ans=(ans+1ll*A[i]*power(i,k)%mod)%mod;
}
qw(ans);puts("");
}
int main(){
int tt;tt=1;
while(tt--)solve();
return 0;
}