题意:给n个数,从这n个数里,任意选一些数,使得这些数的平均值减去中位数尽可能的大
精度卡的好恶心啊。。。。。
首先发现集合个数一定是奇数,任何一个个数为偶数的集合,都可以通过去掉中间较大的那个数,来增大或不变这个值
如1 2 3 4 可以去掉3变成1 2 4
排序后,枚举每个数做中位数,然后三分集合个数
确定中位数之后,剩下的数肯定是越大越好,
会发现这个值并不是随着集合个数单调变化的,所以要三分求凸点。
直接double除法,会有精度问题,可以转换成乘法比较
换了好几种三分形式,都不对。。。。各种怀疑。。。。并不是很理解。。
ac代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 10;
int a[maxn], n;
long long s[maxn];
long long fun(int x, int cnt){
int l = x - cnt;
int r = n - cnt + 1;
long long sum = s[n] - s[r - 1] + s[x] - s[l - 1];
return sum - 1ll * (2 * cnt + 1) * a[x];
}
int main(){
cin >> n;
for(int i = 1; i <= n; i++){
scanf("%d", &a[i]);
}
sort(a + 1, a + n + 1);
if(n <= 2){
cout << 1 << endl;
cout << a[1] << endl;
return 0;
}
for(int i = 1; i <= n; i++){
s[i] = s[i - 1] + a[i];
}
long long minn = 0, ans_index = 1, ans_count = 0;
for(int i = 2; i < n; i++){
int le = 1, ri = min(i - 1, n - i);
long long minn_now = 0, ans = 0;
while(le <= ri){
int mid1 = (le + ri) >> 1;
int mid2 = (ri + 1 + mid1) >> 1;
long long f1 = fun(i, mid1);
long long f2 = fun(i, mid2);
if(f1 * (2 * mid2 + 1) < f2 * (2 * mid1 + 1)){
le = mid1 + 1;
}
else{
ri = mid2 - 1;
}
}
ans = le, minn_now = fun(i, le);
// cout << i << " " << ans << endl;
if(minn_now * (2 * ans_count + 1) > minn * (2 * ans + 1)){
ans_index = i;
ans_count = ans;
minn = minn_now;
}
}
cout << ans_count * 2 + 1 << endl;
for(int i = ans_index - ans_count; i <= ans_index; i++)
cout << a[i] << " ";
for(int i = n - ans_count + 1; i <= n; i++)
cout << a[i] << " ";
cout << endl;
return 0;
}
wa的代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 10;
int a[maxn], n;
long long s[maxn];
long long fun(int x, int cnt){
int l = x - cnt;
int r = n - cnt + 1;
long long sum = s[n] - s[r - 1] + s[x] - s[l - 1];
return sum - 1ll * (2 * cnt + 1) * a[x];
}
int main(){
cin >> n;
for(int i = 1; i <= n; i++){
scanf("%d", &a[i]);
}
sort(a + 1, a + n + 1);
if(n <= 2){
cout << 1 << endl;
cout << a[1] << endl;
return 0;
}
for(int i = 1; i <= n; i++){
s[i] = s[i - 1] + a[i];
}
long long minn = 0, ans_index = 1, ans_count = 0;
for(int i = 2; i < n; i++){
int le = 1, ri = min(i - 1, n - i);
long long minn_now = 0, ans = 0;
while(le < ri){
int mid1 = (2 * le + ri) / 3;
int mid2 = (le + 2 * ri) / 3;
long long f1 = fun(i, mid1);
long long f2 = fun(i, mid2);
if(f1 * (2 * mid2 + 1) <= f2 * (2 * mid1 + 1)){
le = mid1 + 1;
}
else{
ri = mid2 - 1;
}
}
ans = le, minn_now = fun(i, le);
if(minn_now * (2 * ans_count + 1) > minn * (2 * ans + 1)){
ans_index = i;
ans_count = ans;
minn = minn_now;
}
}
cout << ans_count * 2 + 1 << endl;
for(int i = ans_index - ans_count; i <= ans_index; i++)
cout << a[i] << " ";
for(int i = n - ans_count + 1; i <= n; i++)
cout << a[i] << " ";
cout << endl;
return 0;
}
也可以二分来做,比较一下mid+1和mid-1就可以了