题目链接:点我啊╭(╯^╰)╮
题目大意:
给定
a
i
a_i
ai,找到一组
b
i
b_i
bi,满足
0
≤
b
i
≤
a
i
0 \le b_i \le a_i
0≤bi≤ai,且
∑
b
i
=
k
\sum{b_i} = k
∑bi=k
使得
∑
b
i
×
(
a
i
−
b
i
2
)
\sum{b_i \times (a_i - b_i^2)}
∑bi×(ai−bi2) 最大
解题思路:
其实一开始想的是错的,但是和正解其实是同一种思路:
先让
b
i
=
a
i
b_i = \sqrt{a_i}
bi=ai,然后讨论
∑
b
i
\sum{b_i}
∑bi 和
k
k
k 的大小关系
如果
∑
b
i
≤
k
\sum{b_i} \le k
∑bi≤k,说明要增大
b
i
b_i
bi,但是此时增大
b
i
b_i
bi 只会让答案减少
但是发现随着
b
i
b_i
bi 越来越大,由
b
i
−
1
b_i-1
bi−1 到
b
i
b_i
bi 使答案减少的值越来越大
即
f
(
b
i
)
−
f
(
b
i
−
1
)
f(b_i) - f(b_i-1)
f(bi)−f(bi−1) 递减
因此这里对
f
(
b
i
)
−
f
(
b
i
−
1
)
f(b_i) - f(b_i-1)
f(bi)−f(bi−1) 二分一个
x
x
x
若
f
(
b
i
)
−
f
(
b
i
−
1
)
≥
x
f(b_i) - f(b_i-1) \ge x
f(bi)−f(bi−1)≥x ,要减少
f
(
b
i
)
−
f
(
b
i
−
1
)
f(b_i) - f(b_i-1)
f(bi)−f(bi−1),因此增大
b
i
b_i
bi
最后判断
∑
b
i
≥
k
\sum{b_i} \ge k
∑bi≥k,若满足则增大
x
x
x,表示减少使答案减少的值
然后讨论
∑
b
i
≥
k
\sum{b_i} \ge k
∑bi≥k,说明要减少
b
i
b_i
bi
对
f
(
x
)
=
x
×
(
a
−
x
2
)
f(x) = x \times (a -x^2)
f(x)=x×(a−x2) 求导,得
f
′
(
x
)
=
−
3
x
2
+
a
f'(x) = -3x^2 + a
f′(x)=−3x2+a
不是单调的,因此不能二分,我就卡在了这里
正解是直接对增值进行二分,在上面我们说到使答案减少的值越来越大
其实等价于使答案增大的值越来越小,观察
f
′
(
x
)
=
−
3
x
2
+
a
f'(x) = -3x^2 + a
f′(x)=−3x2+a
f
′
(
x
)
f'(x)
f′(x) 在
x
>
0
x>0
x>0 时单调递减,说明
f
(
x
)
f(x)
f(x) 在
x
>
0
x>0
x>0 时增值越来越小,也就是使答案增大的值越来越小
然后就没有了!!!
直接套用上面那个方法(代码)就行了,唯一不同的是现在是对增值二分
当
b
i
b_i
bi 的增值大于二分的
x
x
x 时,就让它增加,也就减少了增值,这样一定保证答案的合理性
只是最后要处理一下
∑
b
i
≥
k
\sum{b_i} \ge k
∑bi≥k,选取最优的
∑
b
i
−
k
\sum{b_i} - k
∑bi−k 个
b
i
b_i
bi 使其减少
时间复杂度:
O
(
n
l
o
g
2
n
)
O(nlog^2n)
O(nlog2n)
#include<bits/stdc++.h>
#define rint register int
#define deb(x) cerr<<#x<<" = "<<(x)<<'\n';
using namespace std;
typedef long long ll;
typedef pair <int,int> pii;
const int maxn = 3e5 + 5;
int n;
ll k, a[maxn], b[maxn];
struct node{
ll num; int id;
bool operator < (const node &A){
return num > A.num;
}
} p[maxn];
bool ck(ll x){
ll all = 0, l, r, mid;
for(int i=1; i<=n; i++){
l = 0, r = a[i], mid;
while(l <= r){
mid = l + r >> 1;
ll tmp1 = mid * (a[i] - mid * mid);
ll tmp2 = (mid - 1) * (a[i] - (mid - 1) * (mid - 1));
if(tmp1 - tmp2 >= x) l = mid + 1;
else r = mid - 1;
}
b[i] = r; all += b[i];
}
return all >= k;
}
void gao(){
ll all = 0, cnt = 0;
for(int i=1; i<=n; i++) {
all += b[i];
if(b[i] == 0) continue;
p[++cnt].id = i;
ll tmp1 = 1ll * b[i] * (a[i] - b[i] * b[i]);
ll tmp2 = 1ll * (b[i] - 1) * (a[i] - 1ll * (b[i] - 1) * (b[i] - 1));
p[cnt].num = tmp2 - tmp1;
}
all -= k; if(all == 0) return;
sort(p+1, p+1+cnt);
for(int i=1; all; i++, all--) b[p[i].id]--;
}
signed main() {
scanf("%d%lld", &n, &k);
for(int i=1; i<=n; i++) scanf("%lld", a+i);
ll l = -4e18, r = 4e18, mid;
while(l <= r){
mid = l + r >> 1;
if(ck(mid)) l = mid + 1;
else r = mid - 1;
}
ck(r); gao();
for(int i=1; i<=n; i++) printf("%d ", b[i]);
}