题目描述
分析
先考虑一下暴力做法:
即枚举 a , b , c a,b,c a,b,c, 求出 d d d, 判断 a 2 + b 2 + c 2 + d 2 是 否 等 于 n a^2+b^2+c^2+d^2 是否等于 n a2+b2+c2+d2是否等于n, 大约 1 e 9 1e9 1e9的复杂度,肯定会超时
下面考虑如何优化暴力做法:
不妨以空间换时间
:
先枚举 c , d c,d c,d的所有可能,储存下来,再枚举 a , b a,b a,b,查找之前存储下来的 c , d c,d c,d,是否有符合要求的,从而得到答案
(符合要求即: a 2 + b 2 + c 2 + d 2 = n a^2+b^2+c^2+d^2 = n a2+b2+c2+d2=n 且 a , b , c , d a,b,c,d a,b,c,d按联合主键上升)
如何查找出符合要求的 c , d c,d c,d呢? 这里有两种做法(二分/哈希
)
二分
O ( n 2 l o g n ) O(n^2logn) O(n2logn)
对储存下来的 c , d c,d c,d,按 c 2 + d 2 , c , d c^2+d^2,c,d c2+d2,c,d的优先级升序排列, 而后二分出一个最小的符合要求的 r e s res res,使 r e s = n − a 2 − b 2 res = n - a^2 - b^2 res=n−a2−b2, 从而求出 c , d c,d c,d
至于如何满足联合主键上升: 枚举 a , b a,b a,b, 将 c , d c,d c,d按上述规则排序即可
哈希
O ( n 2 ) O(n^2) O(n2):
建立 r e s res res到 c c c的哈希映射, 直接 O ( 1 ) O(1) O(1)找出符合要求的 r e s res res, 使 r e s = n − a 2 − b 2 res = n - a^2 - b^2 res=n−a2−b2, 从而求出 c , d c,d c,d.
而建立哈希映射的过程,即可保证 c , d c,d c,d按联合主键上升
N < 5 e 6 N<5e6 N<5e6, 可以直接利用数组建立映射, 不要使用unodered_map
,会超时.
Y总视频讲解(需要权限)
实现
枚举 + 二分
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 5e6 + 9;
struct node
{
int c, d, sum;
};
node no[N*10];
int n, cnt;
bool cmp(node x, node y)
{
if(x.sum != y.sum) return x.sum < y.sum;
if(x.c != y.c) return x.c < y.c;
if(x.d != y.d) return x.d < y.d;
}
int solve(int res)
{
int left = 0, right = cnt;
while(left < right)
{
int mid = (left + right) / 2;
if(no[mid].sum >= res) right = mid;
else left = mid + 1;
}
if(no[left].sum == res) return left;
return -1;
}
int main()
{
cin >> n;
for(int i=0; i*i <= n; i++)
{
for(int j=i; j*j + i*i <=n; j++)
{
no[cnt++] = {i, j, i*i + j*j};
}
}
sort(no,no+cnt,cmp);
for(int i=0; i*i <= n; i++)
{
for(int j=i; i*i + j*j <=n; j++)
{
int res = n - i*i - j*j;
int index = solve(res);
if(index == -1) continue;
cout << i << " " << j << " " << no[index].c << " " << no[index].d << endl;
return 0;
}
}
return 0;
}
枚举 + 哈希
#include <cstdio>
#include <iostream>
#include <cmath>
using namespace std;
const int N = 5e6 + 9;
int vis[N];
int map[N];
int n;
int main()
{
cin >> n;
for(int i=0; i*i<=n; i++)
{
for(int j=i; j*j +i*i <= n; j++)
{
int sum = i*i + j*j;
if(!vis[sum])
{
map[sum] = i;
vis[sum] = 1;
}
}
}
for(int i=0; i*i<=n; i++)
{
for(int j=i; j*j + i*i <= n; j++)
{
int sum = n - i*i - j*j;
if(vis[sum])
{
int k = map[sum];
int l = sqrt(n - i*i - j*j - k*k);
cout << i << " " << j << " " << k << " " << l << endl;
return 0;
}
}
}
return 0;
}