题目
题目描述
在一个果园里,多多已经将所有的果子打了下来,而且按果子的不同种类分成了不同的堆。多多决定把所有的果子合成一堆。
每一次合并,多多可以把两堆果子合并到一起,消耗的体力等于两堆果子的重量之和。可以看出,所有的果子经过 n − 1 n-1 n−1 次合并之后, 就只剩下一堆了。多多在合并果子时总共消耗的体力等于每次合并所耗体力之和。
因为还要花大力气把这些果子搬回家,所以多多在合并果子时要尽可能地节省体力。假定每个果子重量都为 1 1 1 ,并且已知果子的种类 数和每种果子的数目,你的任务是设计出合并的次序方案,使多多耗费的体力最少,并输出这个最小的体力耗费值。
例如有 3 3 3 种果子,数目依次为 1 1 1 , 2 2 2 , 9 9 9 。可以先将 1 1 1 、 2 2 2 堆合并,新堆数目为 3 3 3 ,耗费体力为 3 3 3 。接着,将新堆与原先的第三堆合并,又得到新的堆,数目为 12 12 12 ,耗费体力为 12 12 12 。所以多多总共耗费体力 = 3 + 12 = 15 =3+12=15 =3+12=15 。可以证明 15 15 15 为最小的体力耗费值。
输入格式
共两行。
第一行是一个整数
n
(
1
≤
n
≤
10000
)
n(1\leq n\leq 10000)
n(1≤n≤10000) ,表示果子的种类数。
第二行包含 n n n 个整数,用空格分隔,第 i i i 个整数 a i ( 1 ≤ a i ≤ 20000 ) a_i(1\leq a_i\leq 20000) ai(1≤ai≤20000) 是第 i i i 种果子的数目。
输出格式
一个整数,也就是最小的体力耗费值。输入数据保证这个值小于 2 31 2^{31} 231 。
样例 #1
样例输入 #1
3
1 2 9
样例输出 #1
15
提示
对于 30 % 30\% 30% 的数据,保证有 n ≤ 1000 n \le 1000 n≤1000:
对于 50 % 50\% 50% 的数据,保证有 n ≤ 5000 n \le 5000 n≤5000;
对于全部的数据,保证有 n ≤ 10000 n \le 10000 n≤10000。
代码
就一个优先队列,每次取两个加一下放回去算一下总和就好了,stl直接秒杀。
这里还是手搓一个小根堆堆来练练手吧
这里我不用数组第一个空间,这样寻找一个元素的父节点就直接是这个元素下标除2,感觉比较方便。
lower函数写的好丑,之后再想想能怎么写的简洁点吧。
#include<bits/stdc++.h>
using namespace std;
int heap[114514], top = 1;
void upper(int i) {
if (i == 1 || heap[i] > heap[i/2]) return;
swap(heap[i], heap[i / 2]);
upper(i / 2);
}
void lower(int i) {
if (2 * i > top - 1) return;
// 无右子树
if (2 * i + 1 > top - 1) {
if (heap[i] < heap[2*i]) return;
swap(heap[i], heap[2 * i]);
lower(2 * i);
}
// 有右子树
else {
// 位置正确
if (heap[i] < heap[2 * i] && heap[i] < heap[2 * i + 1]) return;
// 左大右小
else if (heap[i] >= heap[2 * i] && heap[i] < heap[2 * i + 1]) {
swap(heap[i], heap[2 * i]);
lower(2 * i);
}
// 左小右大
else if (heap[i] < heap[2 * i] && heap[i] >= heap[2 * i + 1]) {
swap(heap[i], heap[2 * i + 1]);
lower(2 * i + 1);
}
// 都小
else {
// 右更小
if (heap[2 * i] > heap[2 * i + 1]) {
swap(heap[i], heap[2 * i + 1]);
lower(2 * i + 1);
}
// 左更小
else {
swap(heap[i], heap[2 * i]);
lower(2 * i);
}
}
}
}
int pop() {
if (top == 1) return -1;
int ans = heap[1];
heap[1] = heap[--top];
lower(1);
return ans;
}
void push(int x) {
heap[top] = x;
upper(top++);
}
int length() {
return top - 1;
}
int main() {
int n, temp, ans = 0;
cin >> n;
for (int i = 0; i < n; i++) {
cin >> temp;
push(temp);
}
while (length() != 1) {
temp = pop() + pop();
ans += temp;
push(temp);
}
cout << ans << endl;
return 0;
}