题意
给出n,m,k,p,求
(∑i=0n−1∑j=0m−1max((i⊕j)−k,0))modp
,其中
⊕
表示异或。多组数据。
T<=5000,n,m,k<=10^18,p<=10^9
分析
我的方法是设f[i,0/1,0/1,0/1]表示从高到底做到第i位,是否选出的第1个数是否卡着n的上界,选出的第2个数是否卡着m的上界,选出两个数的异或是否卡着k的下界。然后先转移一遍,然后再倒着跑一遍统计答案即可。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
int MOD,f[65][2][2][2],g[65][2][2][2];
LL bin[65],n,m,w;
void updata(int &x,int y)
{
x+=y;x-=x>=MOD?MOD:0;
}
int solve()
{
int ans=0;
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
f[60][1][1][1]=1;
for (int i=60;i>=1;i--)
for (int j=0;j<=1;j++)
for (int k=0;k<=1;k++)
for (int l=0;l<=1;l++)
if (f[i][j][k][l])
{
int x=f[i][j][k][l];
if (!j||(n&bin[i-1])) updata(f[i-1][j][(!(m&bin[i-1])&&k)?1:0][((w&bin[i-1])&&l)?1:0],x);
if (!k||(m&bin[i-1])) updata(f[i-1][(!(n&bin[i-1])&&j)?1:0][k][((w&bin[i-1])&&l)?1:0],x);
if (!l||!(w&bin[i-1]))
{
updata(f[i-1][(!(n&bin[i-1])&&j)?1:0][(!(m&bin[i-1])&&k)?1:0][l],x);
if ((!j||(n&bin[i-1]))&&(!k||(m&bin[i-1]))) updata(f[i-1][j][k][l],x);
}
}
for (int j=0;j<=1;j++)
for (int k=0;k<=1;k++)
for (int l=0;l<=1;l++)
g[0][j][k][l]=1,updata(ans,MOD-(LL)w%MOD*f[0][j][k][l]%MOD);
for (int i=1;i<=60;i++)
for (int j=0;j<=1;j++)
for (int k=0;k<=1;k++)
for (int l=0;l<=1;l++)
if (f[i][j][k][l])
{
int x=f[i][j][k][l];
if (!j||(n&bin[i-1]))
{
updata(ans,(LL)g[i-1][j][(!(m&bin[i-1])&&k)?1:0][((w&bin[i-1])&&l)?1:0]*x%MOD*(bin[i-1]%MOD)%MOD);
updata(g[i][j][k][l],g[i-1][j][(!(m&bin[i-1])&&k)?1:0][((w&bin[i-1])&&l)?1:0]);
}
if (!k||(m&bin[i-1]))
{
updata(ans,(LL)g[i-1][(!(n&bin[i-1])&&j)?1:0][k][((w&bin[i-1])&&l)?1:0]*x%MOD*(bin[i-1]%MOD)%MOD);
updata(g[i][j][k][l],g[i-1][(!(n&bin[i-1])&&j)?1:0][k][((w&bin[i-1])&&l)?1:0]);
}
if (!l||!(w&bin[i-1]))
{
updata(g[i][j][k][l],g[i-1][(!(n&bin[i-1])&&j)?1:0][(!(m&bin[i-1])&&k)?1:0][l]);
if ((!j||(n&bin[i-1]))&&(!k||(m&bin[i-1]))) updata(g[i][j][k][l],g[i-1][j][k][l]);
}
}
return ans;
}
int main()
{
bin[0]=1;
for (int i=1;i<=60;i++) bin[i]=bin[i-1]*2;
int T;scanf("%d",&T);
while (T--)
{
scanf("%lld%lld%lld%d",&n,&m,&w,&MOD);
n--;m--;
printf("%d\n",solve());
}
return 0;
}