大致题意
给一个NxNxN的立方体,里面的点的范围是(0<=x<=N-1,0<=y<=N-1,0<=z<=N-1)。每个点的点值是x xor y xor z。现要求延平行于坐标轴的非整数点切割立方体,直至切割成NxNxN个单位立方体。每次切割的价值是两个被分开的立方体的价值之和的乘积。求最大价值。
思路
可以想到对于每个点,不断切割,他都要和所有没跟他分到一个立方体的点进行乘积,直到最后自己一个点成为一个单位立方体。所以,每个点都和其他任意点算了一次乘积。因此价值和切割方式无关,是个定值。若ai表示点i的价值,那么容斥一下总的答案就是 ans=1/2*[sum(ai) * sum(ai)-sum(ai*ai)]。然后用FWT计算出每种价值ai出现的次数。然后即可计算。
代码
贴一下FWT的代码,(蒟蒻也是头一次用)好像FWT要求的数组长度和FFT不同。
详细关于FWT的介绍可以看这篇巨巨的博客:https://www.cnblogs.com/cjyyb/p/9065615.html
#include<bits/stdc++.h>
using namespace std;
#define maxn 3000005
#define maxm 3000006
#define ll long long int
#define INF 0x3f3f3f3f
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,r,l) for(int i=r;i>=l;i--)
#define mem(a) memset(a,0,sizeof(a))
#define sqr(x) (x*x)
#define inf (ll)2e18+1
#define PI acos(-1)
#define mod 998244353
int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f*x;
}
int n,a[maxn],len;
ll inv2;
ll fast(ll x,ll y){
ll res=1ll;
while(y){
if(y&1)res=res*x%mod;
x=x*x%mod;
y>>=1;
}
return res;
}
void FWT(int *P,int len,int opt)
{
for(int i=2;i<=len;i<<=1)
for(int p=i>>1,j=0;j<len;j+=i)
for(int k=j;k<j+p;++k)
{
int x=P[k],y=P[k+p];
P[k]=(x+y)%mod;P[k+p]=(x-y+mod)%mod;
if(opt==-1)P[k]=1ll*P[k]*inv2%mod,P[k+p]=1ll*P[k+p]*inv2%mod;
}
}
int main()
{
inv2=fast(2ll,mod-2);
while(~scanf("%d",&n)){
for(len=1;len<=n;len<<=1);
inc(i,0,len-1)a[i]=(i<n);
FWT(a,len,1);
inc(i,0,len-1)a[i]=1ll*a[i]*a[i]%mod*a[i]%mod;
FWT(a,len,-1);
ll ans=0,tmp=0;
inc(i,0,len-1){ans=(ans+1ll*i*a[i]%mod)%mod;tmp=(tmp+1ll*i*i%mod*a[i]%mod)%mod;}
ans=ans*ans%mod;ans=(ans-tmp+mod)%mod;
printf("%lld\n",ans*inv2%mod);
}
return 0;
}