Sort
Description
Recently, Bob has just learnt a naive sorting algorithm: merge sort. Now, Bob receives a task from Alice.
Alice will give Bob N sorted sequences, and the i-th sequence includes ai elements. Bob need to merge all of these sequences. He can write a program, which can merge no more than k sequences in one time. The cost of a merging operation is the sum of the length of these sequences. Unfortunately, Alice allows this program to use no more than T cost. So Bob wants to know the smallest k to make the program complete in time.
Input
The first line of input contains an integer t0, the number of test cases. t0 test cases follow.
For each test case, the first line consists two integers N (2≤N≤100000) and T (∑Ni=1ai<T<231).
In the next line there are N integers a1,a2,a3,…,aN(∀i,0≤ai≤1000).
Output
For each test cases, output the smallest k.
Sample Input
1
5 25
1 2 3 4 5
Sample Output
3
题意
给定n堆序列,和价值T,在价值T之内,每次合并最多m堆序列,每次合并花费合并序列长度的价值,最后合并为一堆,问在T范围内m的最小值。
思路
一次合并k个最优的代价,用k叉哈夫曼做,在合并过程中可以发现,拿出1与k - 1堆合并,就能形成1堆,然后这堆再和k - 1堆合并,依次这样下去,但在最后一步合并可能会出现不够的现象,于是需要在最前面补0。
根据上面的规律,就能推出一个式子(n - 1) % (k - 1),如果刚刚为0,则恰好合并完,不做处理,否则最后一次合并不够k堆,于是需要补 k - (n - 1) % (k - 1) - 1堆,但是不能在最后面补,在合并前补才是最优的。
这个题有个坑点,不能再前面补0,k枚举过大时,会T掉,第一次合并我们可以这样处理,前面x = k - (n - 1) % (k - 1) - 1个0对答案不做贡献,记录y = (n - 1) % (k - 1) + 1 的前缀和,直接作为第一次合并即可, 这里x + y = k.。然后还需合并(n - 1)/(k - 1)次。
时间复杂度:O(n * (logn) ^ 2)
AC代码
#include<cstdio>
#include<queue>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn = 1e5 + 5;
int arr[maxn], sum[maxn];
int n, k;
priority_queue<int, vector<int>, greater<int> > q;
bool judge(int m){
while (!q.empty()) q.pop();
int x = (n - 1) % (m - 1);
ll res = 0;
if (x){
x++;
res += sum[x];
q.push(sum[x]);
}
//n logn
for (int i = x + 1; i <= n; ++i){
q.push(arr[i]);
}
//复杂度(n - 1)/(m - 1) * m * logn == > (n - 1) * logn
for (int i = 0; i < (n - 1) / (m - 1); ++i){
int temp = 0;
int loop = m;
while (loop--){
temp += q.top(); q.pop();
}
res += temp;
q.push(temp);
if (res > k) return false;
}
return res <= k;
}
void solve(){
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; ++i){
scanf("%d", &arr[i]);
}
sort(arr + 1, arr + n + 1);
sum[0] = 0;
for (int i = 1; i <= n; ++i){
sum[i] = sum[i - 1] + arr[i];
}
int l = 1, r = n;
while (r - l > 1){
int m = l + (r - l) / 2;
if (judge(m)) r = m;
else l = m;
}
printf("%d\n", r);
}
int main(){
int t;
scanf("%d", &t);
for (int i = 0; i < t; ++i) solve();
return 0;
}
/**
10
5 100
1 2 3 4 5
5 25
1 2 3 4 5
3 100
1 2 3
2 10
1 2
**/