E. Carrots for Rabbits
time limit per test
1 second
memory limit per test
256 megabytes
input
standard input
output
standard output
There are some rabbits in Singapore Zoo. To feed them, Zookeeper bought nn carrots with lengths a1,a2,a3,…,ana1,a2,a3,…,an. However, rabbits are very fertile and multiply very quickly. Zookeeper now has kk rabbits and does not have enough carrots to feed all of them. To solve this problem, Zookeeper decided to cut the carrots into kk pieces. For some reason, all resulting carrot lengths must be positive integers.
Big carrots are very difficult for rabbits to handle and eat, so the time needed to eat a carrot of size xx is x2x2.
Help Zookeeper split his carrots while minimizing the sum of time taken for rabbits to eat the carrots.
Input
The first line contains two integers nn and kk (1≤n≤k≤105)(1≤n≤k≤105): the initial number of carrots and the number of rabbits.
The next line contains nn integers a1,a2,…,ana1,a2,…,an (1≤ai≤106)(1≤ai≤106): lengths of carrots.
It is guaranteed that the sum of aiai is at least kk.
Output
Output one integer: the minimum sum of time taken for rabbits to eat carrots.
Examples
input
Copy
3 6 5 3 1
output
Copy
15
input
Copy
1 4 19
output
Copy
91
Note
For the first test, the optimal sizes of carrots are {1,1,1,2,2,2}. The time taken is 12+12+12+22+22+22=15
For the second test, the optimal sizes of carrots are {4,5,5,5}. The time taken is 42+52+52+52=91.
题意:
有 n 根胡萝卜,给出每根胡萝卜的长度,现要把这些胡萝卜切成 k 份,分给 k 只兔子(每只兔子一根胡萝卜),兔子吃胡萝卜需要花费的时间为该段胡萝卜的长度的平方,求所有兔子吃完胡萝卜的最短时间
思路:
简单来说,把 n 个数分为 k 个数(n <= k),求最小平方和
考虑划分一个数,最优情况即每次划分都尽量平均,划分的份数越多平方和越小。用优先队列维护每个数在当前基础上再多划分一次所造成的差值,差值大的优先划分
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int N = 1e5 + 7;
struct node {
ll x, y, d;
bool operator < (const node &a) const {
return d < a.d;
}
};
ll solve(ll a, ll b) {
ll x = a / b;
ll y = a % b;
return x * x * (b - y) + y * (x + 1) * (x + 1);
}
int main(){
ll n, k, a;
while(~scanf("%lld%lld", &n, &k)) {
priority_queue<node>q;
ll ans = 0;
for(ll i = 1; i <= n; ++i) {
scanf("%lld", &a);
q.push(node{a, 1, solve(a, 1) - solve(a, 2)});
ans += a * a;
}
while(n < k) {
++n;
node top = q.top();
q.pop();
ans -= top.d;
q.push(node{top.x, top.y + 1, solve(top.x, top.y + 1) - solve(top.x, top.y + 2)});
}
printf("%lld\n", ans);
}
return 0;
}