先不考虑mod k的情况,假设题目条件为a+b^2=c^3。由于1<=a<=b<=c<=n,则a+b^2的范围为[1+b^2 , b+b^2]。对于某个确定的c,我们可以从1到c枚举b,计算c落在多少个[1+b2 , b+b2]中即为方案数。这样做的复杂度是 O(n^2)。
我们可以更高效地利用 1<=a<=b<=c<=n 这个条件,从1到n枚举c。设当前处理的 c 的值为i,使用树状数组维护a+b^2 ,令区间[1+i^2 , i+i^2]加一。之后只需查询i(即c)的值即可。
考虑 mod k 的情况,即a+b^2≡c^3 (mod k)。做法是类似的。从小到大枚举c。使用树状数组维护a+b^2,每次令[(1+i^2)mod k,(i+i^2)mod k]加一。不同之处在于,由于取模的缘故,这个区间可能跨越整个[0,k−1]多次(即 i>=k 的情况)。此时只需在以后计算 c 的答案时都加上⌊i/k⌋即可。时间复杂度 O(n logn)。
#include<bits/stdc++.h>
#define lowbit(x) (x&-x)
#define N 100010
using namespace std;
int T,n,k,tree[N];
long long ans,cnt;
inline void add(int x,int num){
++x;
for(int i=x;i<=k;i+=lowbit(i)) tree[i]+=num;
}
int search(int x){
++x; int re=0;
for(int i=x;i;i-=lowbit(i)) re+=tree[i];
return re;
}
void Insert(int l,int r){
if(l>r){
Insert(l,k+1);
Insert(0,r);
}
else{
add(l,1);
add(r+1,-1);
}
}
int main(){
scanf("%d",&T);
for(int ii=1;ii<=T;++ii){
scanf("%d%d",&n,&k);
ans=cnt=0;
memset(tree,0,sizeof tree);
for(int i=1;i<=n;++i){
cnt+=i/k;
if(i%k) Insert((1ll*i*i+1)%k,(1ll*i*i+i)%k);
ans+=cnt+search(1ll*i*i*i%k);
}
printf("Case %d: %lld\n",ii,ans);
}
return 0;
}