题目大意:
有n个数字,要求选出k个不同的数字使得异或和是s,对所有选择方案求gcd并求和。
$n\le10^6,\ a_i,\,s\le m\le50000$
题解:
首先利用 $\gcd(x_1,\dots,x_k)=\sum_{d\mid\gcd(x_1,\dots,x_k)}\varphi(d)$,答案可以转化为 $\sum_{i=1}^{m} f(i)\varphi(i)$,其中 $f(i)$ 表示选 $k$ 个不同的数字、使得异或和是 $s$ 且 gcd 是 $i$ 的倍数(即所选数字全是 $i$ 的倍数)的方案数。
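然后通过容斥,可以把“选 $k$ 个不同的数字”转化为“允许重复选同一个数字”的有序方案数:先算有序且允许重复的方案数,减去含重复的情况,再除以 $k!$。接着对每个 $i$:若 $i<S$,则直接对 $i$ 的倍数做 FWT;否则用 $O\left(\left(\frac{m}{i}\right)^2\right)$ 暴力枚举两两异或。积分估算复杂度:FWT 部分约为 $S\cdot m\log m$,暴力部分约为 $\sum_{i\ge S}\left(\frac{m}{i}\right)^2\approx\frac{m^2}{S}$。两者平衡时,$S$ 取约 $\sqrt{m}$ 的量级(更精确地为 $\sqrt{m/\log m}$,代码中取 $S=\sqrt{m/7}$)最优。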
#include<bits/stdc++.h>
// Loop shorthands. rep(i,a,b) iterates i = a..b inclusive; note that b is pasted
// unparenthesized into the condition and re-evaluated every iteration, which some
// callers rely on (e.g. rep(j,1,c&&p[j]<=n/i)). Rep(i,v) iterates the indices of v.
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
// Type shorthands and the modulus used for every counted result.
#define lint long long
#define mod 998244353
#define ull unsigned lint
#define db long double
// Container / pair shorthands.
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
// Debug output helpers: debug(x) prints "x=<value>" to stderr; sp / ln are stream
// insertion fragments for a space and an endl.
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
namespace INPUT_SPACE{
// Buffered stdin reader: Buffer holds up to BS bytes, [HD, TL) is the unread part.
const int BS=(1<<24)+5;char Buffer[BS],*HD,*TL;
// Returns the next input byte (0..255), refilling the buffer from stdin when it is
// empty, or EOF once stdin is exhausted. The byte is widened through unsigned char
// so that a 0xFF byte can never compare equal to EOF.
inline int gc()
{
	if(HD==TL) TL=(HD=Buffer)+fread(Buffer,1,BS,stdin);
	return (HD==TL)?EOF:(unsigned char)*HD++;
}
// Reads the next non-negative decimal integer, skipping any non-digit bytes before
// it. Returns 0 if the input ends before any digit is found, instead of looping
// forever on EOF.
inline int inn()
{
	int ch=gc();
	while(ch!=EOF&&(ch<'0'||ch>'9')) ch=gc();
	if(ch==EOF) return 0;
	int x=ch^'0';
	while((ch=gc())>='0'&&ch<='9') x=x*10+(ch^'0');
	return x;
}
}using INPUT_SPACE::inn;
// Size of every value-indexed array. Assumes values <= 50000 (the problem bound), so
// any XOR of two values stays below 2^16 = 65536 < N.
const int N=66000;
// cnt[v]: number of input items equal to v (filled in main).
// bsz[i]: parity of popcount(i) (filled in prelude).
// a: workspace for the FWT branch of F.
// lst, vis, L, R: scratch for the brute-force branch of F.
// sn: threshold between F's FWT and brute-force branches (set in prelude).
// inv6, inv24: modular inverses of 6 and 24 (set in main).
int cnt[N],bsz[N],a[N],lst[N],vis[N],L[N],R[N],sn,inv6,inv24;
// Zeroes the first n ints of a; always returns 0 so it can be used in expressions.
inline int clr(int *a,int n)
{
	std::fill_n(a,n,0);
	return 0;
}
// Euclid's algorithm, iterative form: repeatedly replaces (a, b) with (b % a, a)
// until a becomes 0, then returns b.
inline int gcd(int a,int b)
{
	while(a)
	{
		const int rem=b%a;
		b=a;
		a=rem;
	}
	return b;
}
// Binary exponentiation: returns ans * x^k modulo `mod` (ans defaults to 1).
inline int fast_pow(int x,int k,int ans=1)
{
	while(k)
	{
		if(k&1) ans=(lint)ans*x%mod;
		x=(lint)x*x%mod;
		k>>=1;
	}
	return ans;
}
inline int solvek1(int s) { return !printf("%lld\n",(lint)cnt[s]*s%mod); }
// Answer for k == 2: sum of gcd(a_p, a_q) over unordered pairs {p, q} of distinct
// items with a_p ^ a_q == s, modulo `mod`. Prints it and returns 0.
inline int solvek2(int n,int s)
{
	int ans=0;
	rep(i,1,n)
	{
		// Ordered pairs of distinct items with values (i, i^s). When s == 0 both
		// values coincide, so an item must not be paired with itself.
		const lint pairs=(lint)cnt[i]*(cnt[i^s]-(s==0));
		ans=(ans+pairs%mod*gcd(i,i^s))%mod;
	}
	// Each unordered pair was counted once in each order.
	return !printf("%lld\n",(lint)ans*fast_pow(2,mod-2)%mod);
}
// Tables filled by prelude's linear sieve: p[1..] are the primes found, phi[v] is
// Euler's totient of v, np[v] is 1 for composite v.
int p[N],phi[N],np[N];
inline int prelude(int n)
{
phi[1]=1,sn=(int)sqrt(n/7+0.5);
for(int i=2,c=0;i<=n;i++)
{
if(!np[i]) p[++c]=i,phi[i]=i-1;
rep(j,1,c&&p[j]<=n/i)
{
int x=p[j]*i;np[x]=1;
if(i%p[j]) phi[x]=phi[i]*(p[j]-1);
else { phi[x]=phi[i]*p[j];break; }
}
}
int m=1;while(m<=n) m<<=1;rep(i,1,m-1) bsz[i]=bsz[i>>1]^(i&1);
return 0;
}
inline int fwt(int *a,int n)
{
for(int i=2;i<=n;i<<=1) for(int j=0,t=i>>1,x,y;j<n;j+=i) rep(k,0,t-1)
x=a[j+k],y=a[j+k+t],a[j+k]=(x+y>=mod?x+y-mod:x+y),a[j+k+t]=(x-y<0?x-y+mod:x-y);
return 0;
}
inline int ufwt(int *a,int n,int s)
{
lint ans=0;rep(i,0,n-1) if(bsz[i&s]) ans-=a[i];else ans+=a[i];
return ans%=mod,ans+=mod,ans%=mod,int(ans*fast_pow(n,mod-2)%mod);
}
// Counts ORDERED k-tuples of items (the same item may repeat) whose values are all
// multiples of x and whose XOR equals s, modulo `mod`.
//  - x <= sn: XOR-transform the counts of multiples of x, raise every entry to the
//    k-th power, and read coefficient s of the inverse transform.
//  - x > sn: brute force over pairs of multiples, O((n/x)^2). Only k == 3 and k == 4
//    are supported here; any k other than 3 is treated as 4.
// Uses the global scratch arrays a, vis, lst, L, R. The brute-force branch resets
// every entry it touched before returning.
inline int F(int x,int n,int k,int s)
{
// Transform size: smallest power of two greater than n, so every XOR of values <= n
// fits below it.
int m=1;while(m<=n) m<<=1;
if(x<=sn)
{
clr(a,m);rep(i,1,n/x) a[i*x]=cnt[i*x];fwt(a,m);
rep(i,0,m-1) a[i]=fast_pow(a[i],k);return ufwt(a,m,s);
}
int c=0,ans=0;
// L[v] = R[v] = number of ordered item pairs (values multiples of x) whose XOR is v;
// lst[1..c] records each distinct v the first time it appears (tracked by vis).
for(int i=x;i<=n;i+=x) for(int j=x;j<=n;j+=x)
{
if(!vis[i^j]) vis[lst[++c]=i^j]=1;
L[i^j]=R[i^j]=(L[i^j]+(lint)cnt[i]*cnt[j])%mod;
}
// For k == 3 the second factor is a single item: R[w] becomes cnt[w] for multiples w
// of x and 0 elsewhere. For k == 4 it stays the pair count.
if(k==3) { rep(i,1,c) R[lst[i]]=0;for(int i=x;i<=n;i+=x) R[i]=cnt[i]; }
// Combine a pair with XOR v and a second factor with XOR v ^ s.
rep(i,1,c) if(R[lst[i]^s]) ans=(ans+(lint)L[lst[i]]*R[lst[i]^s])%mod;
// Reset the scratch entries touched above (pair XORs and multiples of x).
rep(i,1,c) vis[lst[i]]=L[lst[i]]=R[lst[i]]=0;
rep(i,0,n/x) L[i*x]=R[i*x]=vis[i*x]=0;return ans;
}
inline int solve(int x,int n,int k,int s)
{
int m=0;rep(i,1,n/x) m+=cnt[i*x];
if(k==1) return s%x==0?cnt[s]:0;
else if(k==2) { int ans=0;for(int i=x;i<=n;i+=x) if((s^i)%x==0) ans=(ans+(lint)cnt[i]*cnt[i^s])%mod;return ans; }
else if(k==3) { int ans=(F(x,n,3,s)-(3ll*m-3+1)*solve(x,n,1,s))%mod;if(ans<0) ans+=mod;return (lint)ans*inv6%mod; }
int ans=(F(x,n,4,s)-(4+6ll*(m-2))*solve(x,n,2,s))%mod;if(ans<0) ans+=mod;return (lint)ans*inv24%mod;
}
int main()
{
// freopen("data.in","r",stdin);
inv6=fast_pow(6,mod-2),inv24=fast_pow(24,mod-2);
int m=inn(),k=inn(),s=inn(),n=s,x,ans=0;
rep(i,1,m) cnt[x=inn()]++,n=max(n,x);prelude(n);
if(k==1) return solvek1(s);if(k==2) return solvek2(n,s);
rep(i,1,n) ans=(ans+(lint)phi[i]*solve(i,n,k,s))%mod;
return !printf("%d\n",ans);
}