地址:http://codeforces.com/contest/1056/problem/B
思路:题意找出n*n中 (i*i+j*j)%m==0的个数,n<1e9,m<1e3,发现遍历的话肯定超时,而m<1e3,那么考虑是不是为m*m的时间复杂度,对于 (i*i+j*j)%m=0,若i>m,那么 (i-m)*(i-m)+j*j=(i*i+j*j+m*m-2*j*m) %m =0,因此i>m的答案都可以由 i0+km推过去,因此只要找m*m的i,j,在通过某个公式即可得到n*n的所有个数。
Code:
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long LL;
const int MAX_M=1e3+5;
LL n,m;
bool boo[MAX_M][MAX_M];
int main()
{
ios::sync_with_stdio(false);
while(cin>>n>>m){
memset(boo,0,sizeof(boo));
LL ans=0;
int pp=min(n,m);
for(int i=1;i<=pp;++i)
for(int j=1;j<=pp;++j)
if((i*i+j*j)%m==0&&!boo[i][j]){
int t1=m-((2*j)%m),t2=m-((2*i)%m);
LL s1=(n-j)/m+1,s2=(n-i)/m+1;
if(t1!=m&&j+t1<=n) s1+=(n-j-t1)/m+1;
if(t2!=m&&i+t2<=n) s2+=(n-i-t2)/m+1;
ans+=s1*s2;
if(j+t1<=m) boo[i][j+t1]=true;
if(i+t2<=m) boo[i+t2][j]=true;
if(i+t2<=m&&j+t1<=m) boo[i+t2][j+t1]=true;
}
cout<<ans<<endl;
}
return 0;
}