题意:
给定一个整数n,我们要将这个整数分为四个非负整数的的平方和,输出字典序最小的一组解
思路:
一、确定方向
观察数据范围是5e6
,因此我们要将时间复杂度控制在O(n)
以内或者O(nlogn)
以内,本题可以考虑用二分来做,也可以用哈希。不过如果用哈希表做法的话则一定不能用STL
中的map
,因为常数过大容易超时,应当用 unordered_map
或 手写哈希。
二分的复杂度是O(nlogn)
,而手写哈希是O(n)
的。两个算法的思想其实是一致的。
二、具体分析
先用 二分 来分析一下这道题。
先考虑一下 a, b, c, d
四个数的范围,由于 N<=5e6
,因此每个数一定满足 <=sqrt(N)≈2200
,这就提示我们最多只能枚举其中两个数(应该有这样对数据的敏感度),如果枚举3
个数(2200^3≈8e9
)可能会超时,
这样我们得到了第一个信息:①至多枚举其中 2
个数。
但是实际上我们需要枚举3
个数(第4
个数d
可以由其它3
个数推出:d = sqrt(N - a ^ 2 - b ^ 2 - c ^ 2)
),那么如何解决这个问题呢?
得到第二个信息:②解决上述问题的常用策略:用空间换时间。具体做法如下:
- 先枚举
c
和d
的所有组合:
for( c = 0; c*c<=N; ++c )
for( d = c; c*c + d*d <= N; ++d )
//d从c开始枚举即可,因为四元组不考虑顺序,只需枚举组合数即可,这里人为规定一个顺序:d>=c
- 将上面两重循环产生的所有
c*c+d*d
的结果存起来(时间复杂度O(n^2)
,n<=2200
,可以接受),之后再枚举a
和b
:
for( a = 0; a*a<=N; ++a )
for( b = a; b*b + a*a <= N; ++b )
- 此时可以算一下:与最终的答案
n
的差值是多少。设t = n - a ^ 2 - b ^ 2
。此时只需判断t
是否在前面算出来的c ^ 2 + d ^ 2
中出现过,就说明找到了一组解,
这样一来,我们就将原本是 O(n^3)
的朴素的复杂度优化成了 O(n ^ 2 * O(if))
(O(if)
表示的是进行上方判断的时间复杂度)
那么对于判断 “某一个数” 在 “某一堆数”中 是否出现过,一般来说有两种做法:
- ①:哈希表
O(1)
,总时间复杂度O(n^2)
- ②:二分
O(logn)
,总时间复杂度O(n^2 * logn)
(需要对前面枚举出的结果集按从小到大排序)
不管是哪种做法,都可以将时间效率大大降低,都是可以通过本题的。
三、细节处理
不过,本题还有一些细节处理,比如:如何找到一组字典序最小的解。
首先,前两个数 a
和 b
已经能保证字典序最小了,因为 a
是从 0
开始从小到大枚举,第一个数 a
可保证最小,b
从 a
开始也是从小到大枚举(b>=a
), b
也能够保证最小,
至此前两个数 a
、b
就一定能保证字典序最小。
此外还要保证后两个数 c
和 d
字典序最小,在判断 t
是否存在于 结果集 中时,同时要保证 t
由字典序最小的 c
和 d
组合而成,
因此,在对 c ^ 2 + d ^ 2
排序的时候其实要对包含 3
个元素的结构体(c ^ 2 + d ^ 2
、c
、d
)按照字典序进行排序,
(具体来说,先对 c ^ 2 + d ^ 2
排序,如果 c ^ 2 + d ^ 2
相同则对 c
排序,如果 c
相同,则对 d
排序)
排完序后,在结构体数组中二分查找 t = n - a ^ 2 - b ^ 2
时,只要找到大于等于 t
的 min
即可(一旦找到,则字典序最小)
时间复杂度:(二分 O(n * logn))
代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 5e6+10;
int n;
int cnt;
struct Node
{
int res, c, d;
bool operator < (const Node &t)const{
if(res!=t.res) return res<t.res;
if(c!=t.c) return c<t.c;
return d<t.d;
}
} node[N];
int main()
{
cin>>n;
//先枚举c和d的所有组合
for(int c = 0; c*c<=N; ++c)
for(int d = c; c*c+d*d<=N; ++d)
node[cnt++] = {c*c+d*d, c, d};
sort(node, node+cnt);
//之后再枚举 a和b,每枚举到一个值 则在上方产生的 c^2+d^2结果集中 二分
for(int a = 0; a*a<=N; ++a)
for(int b = a; a*a+b*b<=N; ++b)
{
int t = n - a*a - b*b;
int l = 0, r = cnt-1;
while(l<r)
{
int mid = l+r>>1;
if(node[mid].res>=t) r = mid;
else l= mid+1;
}
if(node[l].res==t)
{
printf("%d %d %d %d\n", a, b, node[l].c, node[l].d);
return 0;
}
}
return 0;
}