引入
有这样一个问题:给定长度为 n n n 的序列,求其长度在 [ L , R ] [L,R] [L,R] 的前 K K K 大区间和,范围 n ≤ 1 0 5 n\le 10^5 n≤105。
思路
只规定了区间的长度范围,可是哪些才是前 K K K 大的区间呢?我觉得这应该是个小 trick,就是先做个前缀和,那区间 [ l , r ] [l,r] [l,r] 的和就是 s r − s l − 1 , s_r-s_{l-1}, sr−sl−1,开一个大根堆,对于每个右端点 p ≥ l p \ge l p≥l,求出其左端点 -1 的合法区间 [ max ( 0 , p − r ) , p − l ] [\max(0,p-r),p-l] [max(0,p−r),p−l] 的最小值,假设是 s k s_k sk,那答案就是 s p − s k s_p-s_k sp−sk,插入堆中。
每次取出值最大的那个,记入答案,假设取出的值的 s k s_k sk 是这个区间的第 t t t 小,那么就再找出第 t + 1 t+1 t+1 小的 s k ′ s_{k'} sk′,于是 s p − s k ′ s_p-s_{k'} sp−sk′ 就是第 t + 1 t+1 t+1 大,计入答案,插入堆中。当算完第 K K K 个时,就结束了。
至于算区间 kth,开棵主席树随便维护一下就好。
这里有道题: [NOI2010] 超级钢琴
Code:
/*
Program: P2048.cpp
Author: 1l6suj7
DateTime: 2024-01-25 18:09:06
Description:
*/
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define lp(i, j, n) for(int i = j; i <= n; ++i)
#define dlp(i, n, j) for(int i = n; i >= j; --i)
#define mst(n, v) memset(n, v, sizeof(n))
#define mcy(n, v) memcpy(n, v, sizeof(v))
#define INF 1e18
#define MAX4 0x3f3f3f3f
#define MAX8 0x3f3f3f3f3f3f3f3f
#define pii pair<int, int>
#define pll pair<ll, ll>
#define co(x) cerr << (x) << ' '
#define cod(x) cerr << (x) << endl
#define fi first
#define se second
#define eps 1e-8
#define lc(x) ((x) << 1)
#define rc(x) ((x) << 1 ^ 1)
#define pb(x) emplace_back(x)
using namespace std;
const int N = 500010, V = 5e8;
namespace Sg {
const int S = N * 100;
int tot, cnt[S], lc[S], rc[S];
int mdf(int x, int l, int r, int p, int val) {
int nw = ++tot; cnt[nw] = cnt[x], lc[nw] = lc[x], rc[nw] = rc[x];
if(l == r) return ++cnt[nw], nw;
int mid = l + r >> 1;
if(p <= mid) lc[nw] = mdf(lc[nw], l, mid, p, val);
else rc[nw] = mdf(rc[nw], mid + 1, r, p, val);
cnt[nw] = cnt[lc[nw]] + cnt[rc[nw]];
return nw;
}
ll kth(int L, int R, int l, int r, int rk) { // kth min
if(!R) return MAX8;
if(l == r) return l;
int mid = l + r >> 1, t = cnt[lc[R]] - cnt[lc[L]];
if(rk <= t) return kth(lc[L], lc[R], l, mid, rk);
else return kth(rc[L], rc[R], mid + 1, r, rk - t);
}
} using namespace Sg;
int n, K, L, R, rt[N], a[N];
ll ans = 0;
struct node {
int p, k; ll res;
bool operator < (const node & t) const { return res < t.res; }
};
priority_queue<node> q;
signed main() {
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
#ifndef READ
ios::sync_with_stdio(false);
cin.tie(0);
#endif
cin >> n >> K >> L >> R;
rt[1] = mdf(rt[1], -V, V, 0, 1);
lp(i, 1, n) cin >> a[i], a[i] += a[i - 1], rt[i + 1] = mdf(rt[i], -V, V, a[i], 1);
lp(i, L, n) {
ll t = a[i] - kth(rt[max(i - R, 0)], rt[i - L + 1], -V, V, 1);
// cod(t);
q.push({ i, 1, t });
}
int cnt = 0;
while(cnt < K) {
node fr = q.top(); q.pop();
ll t = a[fr.p] - kth(rt[max(fr.p - R, 0)], rt[fr.p - L + 1], -V, V, fr.k + 1);
// cod(t);
ans += fr.res, ++cnt;
q.push({ fr.p, fr.k + 1, t });
}
cout << ans;
return 0;
}
如果是异或和?
可以用类似的思路,先做个前缀异或和,答案就是 s r xor s l − 1 s_r\,\text{xor}\,s_{l-1} srxorsl−1,然后开个大根堆,先一次对每个右端点求出其对应区间的最大异或和,然后每次取出堆中最大值,重新计算答案。
什么,你说怎么找到第 k k k 大的 s r xor s l − 1 s_r \,\text{xor}\,s_{l-1} srxorsl−1?开棵 01-trie 维护一下就好,类似于线段树上二分在 01-trie 上二分,因为要求区间第 k 大,所以用可持久化 01-trie,和其他可持久化数据结构大同小异,糊一糊就出来了。
这里有道题:[十二省联考 2019] 异或粽子
Code:
/*
Program: P5283.cpp
Author: 1l6suj7
DateTime: 2024-01-25 17:20:30
Description:
*/
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define lp(i, j, n) for(int i = j; i <= n; ++i)
#define dlp(i, n, j) for(int i = n; i >= j; --i)
#define mst(n, v) memset(n, v, sizeof(n))
#define mcy(n, v) memcpy(n, v, sizeof(v))
#define INF 1e18
#define MAX4 0x3f3f3f3f
#define MAX8 0x3f3f3f3f3f3f3f3f
#define pii pair<int, int>
#define pll pair<ll, ll>
#define co(x) cerr << (x) << ' '
#define cod(x) cerr << (x) << endl
#define fi first
#define se second
#define eps 1e-8
#define lc(x) ((x) << 1)
#define rc(x) ((x) << 1 ^ 1)
#define pb(x) emplace_back(x)
using namespace std;
const int N = 500010;
namespace Trie {
const int S = 100 * N;
const int M = 33;
int ch[S][2], sum[S], tot;
int ins(int x, int id, ll val) {
int nw = ++tot; ch[nw][0] = ch[x][0], ch[nw][1] = ch[x][1], sum[nw] = sum[x];
if(!id) return ++sum[nw], nw;
ch[nw][(val >> id - 1) & 1] = ins(ch[nw][(val >> id - 1) & 1], id - 1, val);
sum[nw] = sum[ch[nw][0]] + sum[ch[nw][1]];
return nw;
}
ll kth(int x, int id, ll rk, ll val, ll res) {
if(!id) return res ^ val;
int t = (val >> id - 1) & 1;
if(rk <= sum[ch[x][t ^ 1]]) return kth(ch[x][t ^ 1], id - 1, rk, val, res + (t ^ 1) * (1ll << id - 1));
else return kth(ch[x][t], id - 1, rk - sum[ch[x][t ^ 1]], val, res + t * (1ll << id - 1));
}
} using namespace Trie;
struct node {
int p, k; ll val;
bool operator < (const node & t) const { return val < t.val; }
};
int n, K, rt[N];
ll ans, a[N];
priority_queue<node> q;
signed main() {
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
#ifndef READ
ios::sync_with_stdio(false);
cin.tie(0);
#endif
cin >> n >> K;
rt[0] = ins(rt[0], M, 0);
lp(i, 1, n) cin >> a[i], a[i] ^= a[i - 1], rt[i] = ins(rt[i - 1], M, a[i]);
lp(i, 1, n) {
ll v = kth(rt[i - 1], M, 1, a[i], 0);
// cod(v);
q.push({ i, 1, v });
}
int cnt = 0;
while(cnt < K) {
node fr = q.top(); q.pop();
ans += fr.val, ++cnt;
ll v = kth(rt[fr.p - 1], M, fr.k + 1, a[fr.p], 0);
q.push({ fr.p, fr.k + 1, v });
}
cout << ans << endl;
return 0;
}
如果 k 很大呢?
例如这道题:[CF241B] Friends
给出长度为 n n n 的序列 a i a_i ai,求前 m m m 大两个元素异或值之和。
m m m 达到了 1 0 9 10^9 109,如果用上面的方法,肯定会超时。
我们可以先二分出第 m m m 大异或值 m i d mid mid,然后对于每个 a i a_i ai,在可持久化 01-trie 上二分出在 a i a_i ai 左边,与 a i a_i ai 异或不小于 m i d mid mid 的个数,然后根据它和 m m m 的大小关系调整二分边界。
现在我们二分出第 m m m 大异或值了,于是要计算不小于这个异或值的异或值之和。同样是对于每个 a i a_i ai,在 01-trie 上二分,把计算异或值不小于 m i d mid mid 的个数改为计算异或值之和。
具体地,需要维护出每个子树个数和,以及每一位为 1 的个数和,然后算异或值之和时每一位分开计算。
时间复杂度 O ( n log n log V ) O(n\log n\log V) O(nlognlogV), V V V 为 a i a_i ai 的值域上界。
其实还有更优的做法但我没看懂。
Code:
/*
Program: CF241B.cpp
Author: 1l6suj7
DateTime: 2024-01-26 14:36:49
Description:
*/
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define lp(i, j, n) for(int i = j; i <= n; ++i)
#define dlp(i, n, j) for(int i = n; i >= j; --i)
#define mst(n, v) memset(n, v, sizeof(n))
#define mcy(n, v) memcpy(n, v, sizeof(v))
#define INF 1e18
#define MAX4 0x3f3f3f3f
#define MAX8 0x3f3f3f3f3f3f3f3f
#define pii pair<int, int>
#define pll pair<ll, ll>
#define co(x) cerr << (x) << ' '
#define cod(x) cerr << (x) << endl
#define fi first
#define se second
#define eps 1e-8
#define lc(x) ((x) << 1)
#define rc(x) ((x) << 1 ^ 1)
#define pb(x) emplace_back(x)
using namespace std;
const int N = 50010;
const ll MOD = 1e9 + 7;
int n, m, a[N];
namespace Trie {
const int S = 35 * N;
const int M = 31;
int ch[S][2], tot, cnt[S], num[S][M + 1];
int ins(int x, ll val, int id = M) {
int nw = ++tot; ch[nw][0] = ch[x][0], ch[nw][1] = ch[x][1], cnt[nw] = cnt[x];
lp(i, 0, M - 1) num[nw][i] = num[x][i];
++cnt[nw]; lp(i, 0, M - 1) if((val >> i) & 1) ++num[nw][i];
if(!id) return nw;
ch[nw][(val >> id - 1) & 1] = ins(ch[nw][(val >> id - 1) & 1], val, id - 1);
return nw;
}
int gcnt(int x, ll val, ll xval, int id = M) {
if(!id) return cnt[x];
int t = (val >> id - 1) & 1, xt = (xval >> id - 1) & 1;
if(xt == 0) return gcnt(ch[x][t], val, xval, id - 1) + cnt[ch[x][t ^ 1]];
return gcnt(ch[x][t ^ 1], val, xval, id - 1);
}
ll gsum(int x, ll val, ll xval, int id = M) {
if(!id) return 1ll * cnt[x] * xval % MOD;
int t = (val >> id - 1) & 1, xt = (xval >> id - 1) & 1;
if(xt == 0) {
ll res = 0;
lp(i, 0, M - 1) {
if((val >> i) & 1) res += 1ll * (cnt[ch[x][t ^ 1]] - num[ch[x][t ^ 1]][i]) << i;
else res += 1ll * num[ch[x][t ^ 1]][i] << i;
}
return gsum(ch[x][t], val, xval, id - 1) + res;
}
return gsum(ch[x][t ^ 1], val, xval, id - 1);
}
} using namespace Trie;
int rt[N];
signed main() {
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
#ifndef READ
ios::sync_with_stdio(false);
cin.tie(0);
#endif
cin >> n >> m;
lp(i, 1, n) cin >> a[i], rt[i] = ins(rt[i - 1], a[i]);
if(m == 0) return cout << 0 << endl, 0;
ll l = 0, r = 3e9;
while(l < r) {
ll mid = 1ll * l + r + 1 >> 1;
ll t = 0;
lp(i, 1, n) t += gcnt(rt[i - 1], a[i], mid);
if(t >= m) l = mid;
else r = mid - 1;
}
ll t = 0;
lp(i, 1, n) t += gcnt(rt[i - 1], a[i], l);
ll ans = -(t - m) * l;
lp(i, 1, n) ans += gsum(rt[i - 1], a[i], l);
cout << ans % MOD << endl;
return 0;
}