根据题目条件可以推得
f(2n) = 3 * f(n) , f(2n + 1) = f(2n) + 1 = 3 * f(n) + 1 .
也就是把一个二进制数直接看成三进制数,接着就是从 1 到 n(3进制)中模 k 余 (0 ~ k-1 , 十进制)的数的个数的 xor。(看题解的,本来我也没想到)
用数位dp就可以解决了,刚开始的数位dp我写得有点问题,照着题解(我看的第一个题解数位dp写得有问题,我Re后看程序第一眼觉得有问题,想了一想觉得没问题==,还是太渣,连基础数位dp都不会)折腾了我一个下午和晚上。
PS : 请在hihocoder交==
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <stack>
#include <vector>
#include <cstring>
#include <map>
#include <queue>
#define msc(X) memset(X,-1,sizeof(X))
#define ms(X) memset(X,0,sizeof(X))
typedef long long LL;
using namespace std;
int bit[70],fac[70],wh;
LL dp[4][70][65537];
LL dfs(int w,int cur,int k,bool lmt)
{
if(w==0) return cur==0;
if(!lmt&&dp[wh][w][cur]!=-1) return dp[wh][w][cur];
int x=(lmt?bit[w]:1);
LL ans=0;
for(int i=0;i<=x;i++)
ans+=dfs(w-1,(cur-i*fac[w-1]+k)%k,k,lmt&&i==x);
if(!lmt) dp[wh][w][cur]=ans;
return ans;
}
int main(int argc, char const *argv[])
{
int t;
scanf("%d",&t);
fac[0]=1;
msc(dp);
while(t--)
{
LL n;int k;
scanf("%lld%d",&n,&k);
if(k==3) printf("%lld\n",(n>>1)^((n+1)>>1) );
else
{
int top=0;
if(k==5) wh=0;
else if(k==17) wh=1;
else if(k==257) wh=2;
else wh=3;
while(n) bit[++top]=(n&1),n>>=1;
for(int i=1;i<70;i++)
fac[i]=fac[i-1]*3%k;
for(int i=0;i<k;i++)
n^=dfs(top,i,k,true)-(i==0);
printf("%lld\n",n );
}
}
return 0;
}