算法作用
- 查询区间第k大,第k小
- 查询区间前k大之和,前k小之和
Code
ACWing255. 第K小数
区间第k小
const int N = 3e6 + 5, M = 1e5 + 5;
struct node
{
int lc, rc;//左右儿子编号
int cnt;//值域区间[l, r]中一共插入过多少数
}t[N];
#define tl t[p].lc
#define tr t[p].rc
#define tp t[p]
int root[M], tot;//可持久化线段树的每个根和总节点数
int n, a[M], b[M], num;
int build(int l, int r)
{
int p = ++tot;
if (l == r) { tp.cnt = 0; return p; }
int mid = l + r >> 1;
tl = build(l, mid), tr = build(mid + 1, r);
tp.cnt = t[tl].cnt + t[tr].cnt;
return p;
}
//单点修改(对于第i次修改,以可持久化线段树的第i-1个版本为基础)
int insert(int now, int l, int r, int x, int v)
{
int p = ++tot;
tp = t[now];
if (l == r)
{
tp.cnt += v;
return p;
}
int mid = l + r >> 1;
if (x <= mid) tl = insert(t[now].lc, l, mid, x, v);
else tr = insert(t[now].rc, mid + 1, r, x, v);
tp.cnt = t[tl].cnt + t[tr].cnt;
return p;
}
//在p,q两个节点上,区域为[l, r],求第k小数
int ask(int p, int q, int l, int r, int k)
{
if (l == r) return l;
int mid = l + r >> 1;
int lcnt = t[t[p].lc].cnt - t[t[q].lc].cnt;
if (k <= lcnt) return ask(t[p].lc, t[q].lc, l, mid, k);
else return ask(t[p].rc, t[q].rc, mid + 1, r, k - lcnt);
}
int main()
{
IOS;
int m; cin >> n >> m;
for (int i = 1; i <= n; i++)
cin >> a[i], b[i] = a[i];
//离散化
sort(b + 1, b + n + 1);
num = unique(b + 1, b + n + 1) - b - 1;
root[0] = build(1, num);
for (int i = 1; i <= n; i++)
{
a[i] = lower_bound(b + 1, b + num + 1, a[i]) - b;//修改为离散化的值
root[i] = insert(root[i - 1], 1, num, a[i], 1);
}
//可持久化线段树中“以root[i]为根的线段树”的值域区间[l, r]
//保存了a的前i个数有多少个数落在值域区间[l, r]内
while (m--)
{
int l, r, k; cin >> l >> r >> k;
int ans = ask(root[r], root[l - 1], 1, num, k);
cout << b[ans] << endl;
}
return 0;
}
区间前k大之和
Code
struct node
{
int lc, rc;
int cnt;
ll sum;
}t[N];
#define tl t[p].lc
#define tr t[p].rc
#define tp t[p]
int root[M], tot;
int n, a[M], b[M], num;
int build(int l, int r)
{
int p = ++tot;
if (l == r) { tp.cnt = tp.sum = 0; return p; }
int mid = l + r >> 1;
tl = build(l, mid), tr = build(mid + 1, r);
tp.cnt = tp.sum = 0;
return p;
}
int insert(int now, int l, int r, int x, ll v)//位置x插入数值v
{
int p = ++tot;
tp = t[now];
if (l == r)
{
tp.cnt += 1;
tp.sum += v;
return p;
}
int mid = l + r >> 1;
if (x <= mid) tl = insert(t[now].lc, l, mid, x, v);
else tr = insert(t[now].rc, mid + 1, r, x, v);
tp.cnt = t[tl].cnt + t[tr].cnt;
tp.sum = t[tl].sum + t[tr].sum;
return p;
}
//区间前k大之和
ll ask(int p, int q, int l, int r, int k)
{
if (l == r) return (t[p].sum - t[q].sum) / (t[p].cnt - t[q].cnt) * k;//特别注意(一个点可能包含多个值)
int mid = l + r >> 1;
ll sum = 0;
int rcnt = t[t[p].rc].cnt - t[t[q].rc].cnt;
if (k <= rcnt) sum = ask(t[p].rc, t[q].rc, mid + 1, r, k);
else sum = ask(t[p].lc, t[q].lc, l, mid, k - rcnt) + t[t[p].rc].sum - t[t[q].rc].sum;
return sum;
}