问题描述:
给定一个整数序列a[1..N],定义sum[i][j]=a[i]+a[i+l]+……+a[j],将所有的sum[i][j]从小到大排序(其中i,j满足1<=i<=j<=N),得到一个长为N*(N+1)/2的序列,求该序列中的第k个元素。
输入格式(ktm.in)
第一行有两个整数N,k,其中0<N<=20000,1<=k<=N*(N+1)/2,数据保证任何一个sum[i][j]的绝对值不超过2^30。
接下来N行每行一个整数。顺序给出序列a的元素。
输出格式(kth.out)
sum序列中的第k个元素
输入样例
5 15
1
2
3
4
5
输出样例
15
30%的数据满足n<=1000;
另外30%的数据满足n<=20000,k<=100000并且所有元素都为正整数。
这道题我一开始受 ccl 的影响,以为是 k 短路问题,但是看到 k 的范围瞬间觉得不科学。
接着 lyp 跟我讲了二分答案的做法,觉得非常厉害。
也就是说,每二分出一个值,便到所有的区间和中查找他的 rank,直到出解为止。
但是还有一个问题:序列中存在负数,也就是说区间和并不是单调的的。那么如何查找 rank 呢?
我的实现是使用平衡树(treap),每次查找的复杂度为 O(logn)。
这样一来,二分答案复杂度为 O(log(1 << 31)), 枚举区间端点复杂度为 O(n),查找为 O(logn),总复杂度为 O(nlogn2)。
总结:又学会了一种问题的解决方法。
Code :
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <climits>
#include <iostream>
#include <algorithm>
typedef int MAINTYPE;
typedef long long int64;
typedef unsigned int uint;
typedef unsigned long long uint64;
#define swap(a, b, t) ({t _ = (a); (a) = (b); (b) = _;})
#define MAX(a, b, t) ({t _ = (a), __ = (b); _ > __ ? _ : __;})
#define MIN(a, b, t) ({t _ = (a), __ = (b); _ < __ ? _ : __;})
#define max(a, b) MAX(a, b, MAINTYPE)
#define min(a, b) MIN(a, b, MAINTYPE)
#define maxn 200105
#define random (rand() * rand())
#define sum(i, j) (s[j] - s[(i) - 1])
int64 n, k;
int a[maxn], s[maxn];
int mini, maxi, ans, vetot;
struct node{node * c[2]; int data, rank, size;} vess[maxn], * root;
void update(node * & p)
{
if (! p) return;
p->size = (p->c[0] ? p->c[0]->size : 0) + (p->c[1] ? p->c[1]->size : 0) + 1;
}
void rotate(node * & p, bool flag)
{
node * q = p->c[! flag];
p->c[! flag] = q->c[flag], q->c[flag] = p;
update(p), update(q), p = q;
}
void insert(node * & p, int data)
{
if (! p)
{
p = vess + ++ vetot;
p->data = data, p->rank = random, p->size = 1;
p->c[0] = p->c[1] = 0;
}
else
{
bool flag = data > p->data;
insert(p->c[flag], data), ++ p->size;
if (p->c[flag]->rank < p->rank) rotate(p, ! flag);
}
}
int getrank(int data)
{
node * p = root;
int rank = 0;
while (p)
{
rank += p->c[0] ? p->c[0]->size + 1: 1;
if (data <= p->data)
{
rank -= p->c[0] ? p->c[0]->size + 1: 1;
p = p->c[0];
}
else
p = p->c[1];
}
return rank;
}
void work()
{
if (k == 1) return (void) printf("%d\n", mini);
if (k == n * (n + 1) / 2) return (void) printf("%d\n", maxi);
for (int l = mini, r = maxi; l < r; )
{
int mid = (l + r) >> 1, rank = 0;
vetot = 0, root = 0;
for (int i = n; i >= 1; -- i)
{
insert(root, s[i]);
rank += getrank(mid + s[i - 1]);
}
rank < k ? (l = mid + 1) : (r = mid);
ans = l - 1;
}
printf("%d\n", ans);
}
void prepare()
{
mini = INT_MAX, maxi = INT_MIN;
for (int i = 1, s = 0, t = 0; i <= n; ++ i)
{
mini = min(mini, s += a[i]);
maxi = max(maxi, t += a[i]);
if (s > 0) s = 0;
if (t < 0) t = 0;
}
}
void init()
{
scanf("%I64d%I64d", & n, & k);
for (int i = 1; i <= n; ++ i)
scanf("%d", & a[i]), s[i] = s[i - 1] + a[i];
}
int main()
{
freopen("kth.in", "r", stdin);
freopen("kth.out", "w", stdout);
init();
prepare();
work();
return 0;
}