思路
权值线段树：每个节点对应一段值域 [l, r]，sum 记录当前落在该值域内的元素个数，而不是维护数组下标区间上的信息；查询第 k 小时，根据左儿子的 sum 决定向左还是向右走，直到叶子。
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
// 64-bit integer alias.
using ll = long long;
// Capacity constant: sizes a[] and the node pool, and is the upper end of the
// value range handed to build() in main.
const int N = 1000000 + 10;
// n: how many initial elements main reads; q: how many operations follow.
int n, q;
// a[i]: current value stored at position i (main fills indices 1..n).
int a[N];
// Heap-style child indices of the node whose index is held in a variable named `p`.
#define lc ((p) << 1)
#define rc (((p) << 1) | 1)
// One segment-tree node. The tree is indexed by value, not by array position.
struct Node{
    int l;    // smallest value covered by this node (inclusive)
    int r;    // largest value covered by this node (inclusive)
    int sum;  // how many stored elements currently have a value in [l, r]
};
// Node pool; the root is index 1 and the children of p are 2p and 2p+1.
Node tr[N<<2];
// Initialises node p to cover the value range [l, r] with a count of zero,
// then recursively does the same for both halves until single-value leaves.
void build(int p,int l,int r){
    tr[p].l = l;
    tr[p].r = r;
    tr[p].sum = 0;
    if(l != r){
        const int mid = (l + r) >> 1;
        build(lc, l, mid);       // left child takes [l, mid]
        build(rc, mid + 1, r);   // right child takes [mid+1, r]
    }
}
// Recomputes the count of node p as the total of its two children's counts.
void push_up(int p){
    const int left_count = tr[lc].sum;
    const int right_count = tr[rc].sum;
    tr[p].sum = left_count + right_count;
}
// Adds v to the occurrence count of value `val` in the subtree rooted at p.
// Descends to the leaf covering `val`, then refreshes every count on the way back up.
void update(int p,int val,int v){
    if(tr[p].l != tr[p].r){
        const int mid = (tr[p].l + tr[p].r) >> 1;
        // Values up to mid live in the left child, the rest in the right child.
        update(val <= mid ? lc : rc, val, v);
        push_up(p);
        return;
    }
    // Leaf: this node stands for exactly one value.
    tr[p].sum += v;
}
// Returns the k-th smallest stored value within the subtree rooted at p.
// At each internal node the left child's count decides the direction: if it
// holds at least k elements the answer is on the left, otherwise the left
// elements are skipped and the search continues on the right for rank
// k - (left count).
// The descent does not validate k; callers are expected to pass a rank
// between 1 and the number of elements counted under p.
int query(int p,int k){
    while(tr[p].l != tr[p].r){
        const int left_count = tr[lc].sum;
        if(left_count >= k){
            p = lc;
        }else{
            k -= left_count;
            p = rc;
        }
    }
    // Leaf reached: its range is a single value, which is the answer.
    return tr[p].l;
}
// Entry point of the segment-tree solution.
// Reads n and q, inserts the n initial values into the value-indexed tree,
// then handles q operations:
//   op == 1 : reads k and prints the k-th smallest stored value;
//   otherwise: reads x and y and replaces the value at position x with y.
int main(){
    std::ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin >> n >> q;
    build(1, 1, N);
    for(int i = 1; i <= n; ++i){
        cin >> a[i];
        update(1, a[i], 1);
    }
    while(q--){
        int op;
        cin >> op;
        if(op != 1){
            int x, y;
            cin >> x >> y;
            update(1, a[x], -1);  // take the old value out of the multiset
            a[x] = y;
            update(1, y, 1);      // put the new value in
            continue;
        }
        int k;
        cin >> k;
        cout << query(1, k) << '\n';
    }
    return 0;
}
树状数组部分
当然也可以用树状数组来写
离散化版
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long ll;
// N sizes a[] and the stored operation list; M is not referenced by the
// discretized code visible here.
const int N = 3e5 + 10,M = 1e6+10;
// n, q : element count and operation count read by solve().
// tr   : Fenwick-tree counts indexed by compressed rank 1..ve.size().
//        ve receives every initial value AND every value assigned by an
//        update, so there can be up to n + q distinct ranks; tr therefore
//        needs 2*N slots (assuming n and q are each bounded by N, as the
//        sizes of a[] and ask[] suggest).
// a    : current value at each position (1-based).
// mx   : not used by the code visible here.
int n,q,tr[N<<1],a[N],mx;
// Values collected for coordinate compression; sorted and deduplicated in solve().
vector<int> ve;
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit(int x){
    const int negated = -x;
    return x & negated;
}
// Adds v to the count stored at compressed rank u in the Fenwick tree.
// u is a 1-based rank in [1, ve.size()]; ranks outside that range leave the
// tree untouched.
void add(int u,int v){
    // Rank 0 would never advance (lowbit(0) == 0) and negative ranks are not
    // valid indices, so reject both up front.
    if(u < 1) return;
    // Read the size once as a signed value so the loop compares int with int.
    const int limit = static_cast<int>(ve.size());
    for(int i = u; i <= limit; i += lowbit(i)){
        tr[i] += v;
    }
}
// Returns the Fenwick prefix sum over ranks 1..x, i.e. how many stored
// elements have a compressed rank not greater than x.
int query(int x){
    int total = 0;
    int pos = x;
    while(pos){
        total += tr[pos];
        pos -= lowbit(pos);  // jump to the block that precedes this one
    }
    return total;
}
// Look up a value by rank.
// Returns the rank-th smallest stored value (1-based) by binary descent over
// the Fenwick tree: idx grows to the largest position whose prefix count is
// still below `rank`, so the answer sits at 1-based rank idx + 1, which is
// element ve[idx] of the 0-based sorted value list.
// NOTE: rank is not validated. If it exceeds the number of stored elements
// the walk ends at idx == ve.size() and the final read is past the end of ve;
// callers must pass 1 <= rank <= element count.
int kth(int rank){
    // Read the size once as a signed value so comparisons stay int vs int.
    const int limit = static_cast<int>(ve.size());
    int idx = 0;
    for(int bit = 1 << 20; bit > 0; bit >>= 1){
        const int next = idx + bit;
        // Take the step only if it stays inside the tree and the block it
        // covers does not already contain the wanted rank.
        if(next <= limit && tr[next] < rank){
            idx = next;
            rank -= tr[next];
        }
    }
    return ve[idx];
}
// Maps a value to its 1-based compressed rank: one plus the offset of the
// first element of the sorted vector ve that is not less than x.
int _get(int x){
    auto pos = lower_bound(ve.begin(), ve.end(), x);
    return static_cast<int>(pos - ve.begin()) + 1;
}
// One operation read from input, stored so it can be replayed after all
// values have been collected for compression.
struct Query{
    int opt;  // operation type: 1 = k-th query, anything else = assignment
    int x;    // rank k for a query, or the position to overwrite
    int y;    // new value for an assignment; 0 for a query
};
// Stored operations, filled at indices 1..q by solve().
Query ask[N];
// Offline solution with coordinate compression.
// 1. Reads the n initial values and all q operations, collecting every value
//    that can ever be stored (initial values and assigned values) into ve.
// 2. Sorts and deduplicates ve so each value gets a compressed rank.
// 3. Inserts the initial values into the Fenwick tree and replays the
//    operations in order, printing one line per k-th query.
void solve(){
    cin >> n >> q;
    for(int i = 1; i <= n; ++i){
        cin >> a[i];
        ve.push_back(a[i]);
    }

    int kind, first, second;
    for(int i = 1; i <= q; ++i){
        second = 0;
        // Both operation types carry at least one argument after the type.
        cin >> kind >> first;
        if(kind != 1){
            cin >> second;
            ve.push_back(second);
        }
        ask[i] = {kind, first, second};
    }

    sort(ve.begin(), ve.end());
    ve.erase(unique(ve.begin(), ve.end()), ve.end());

    for(int i = 1; i <= n; ++i){
        add(_get(a[i]), 1);
    }

    for(int i = 1; i <= q; ++i){
        const auto &cur = ask[i];
        if(cur.opt == 1){
            cout << kth(cur.x) << '\n';
            continue;
        }
        add(_get(a[cur.x]), -1);  // remove the value being overwritten
        a[cur.x] = cur.y;
        add(_get(cur.y), 1);      // insert the replacement
    }
}
// Entry point: unties the C++ streams from C stdio and from each other, then
// runs the solver once.
int main(){
    std::ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    solve();
    return 0;
}
非离散化版
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
// 64-bit integer alias.
using ll = long long;
// N sizes the position array a[]; M sizes the value-indexed Fenwick tree tr[].
const int N = 300000 + 10;
const int M = 1000000 + 10;
// n, q: element count and operation count read by solve().
int n, q;
// tr[v]: Fenwick-tree counts indexed directly by value.
int tr[M];
// a[i]: current value at position i (1-based).
int a[N];
// Not used by the code visible here.
int mx;
// Isolates the lowest set bit of x; yields 0 when x is 0.
int lowbit(int x){
    int isolated = x;
    isolated &= -x;
    return isolated;
}
// Adds v to the occurrence count of value u in the value-indexed Fenwick tree.
// Only values in [1, 1000000] are tracked; anything outside that range leaves
// the tree untouched.
void add(int u,int v){
    // Integer bound: comparing against the floating-point literal 1e6 would
    // convert the index to double on every iteration.
    const int kMaxValue = 1000000;
    // Index 0 would never advance (lowbit(0) == 0) and negative indices would
    // write before the array, so reject both up front.
    if(u < 1) return;
    for(int i = u; i <= kMaxValue; i += lowbit(i)){
        tr[i] += v;
    }
}
// Returns the Fenwick prefix sum over indices 1..x, i.e. how many stored
// elements have a value not greater than x.
int query(int x){
    int total = 0;
    int pos = x;
    while(pos){
        total += tr[pos];
        pos -= lowbit(pos);  // jump to the block that precedes this one
    }
    return total;
}
// Look up a value by rank.
// Returns the rank-th smallest stored value (1-based) by binary descent over
// the value-indexed Fenwick tree: idx grows to the largest value whose prefix
// count is still below `rank`, so the answer is idx + 1.
// rank is not validated; for a rank larger than the number of stored elements
// the walk runs to the top of the range and the result is 1000001.
int kth(int rank){
    // Integer bound: comparing against the floating-point literal 1e6 would
    // convert the index to double on every iteration.
    const int kMaxValue = 1000000;
    int idx = 0;
    for(int bit = 1 << 20; bit > 0; bit >>= 1){
        const int next = idx + bit;
        // Take the step only if it stays inside the tracked range and the
        // block it covers does not already contain the wanted rank.
        if(next <= kMaxValue && tr[next] < rank){
            idx = next;
            rank -= tr[next];
        }
    }
    return idx + 1;
}
// Online solution without compression: values index the Fenwick tree directly.
// Reads n and q, inserts the n initial values, then handles q operations:
//   type 1   : reads k and prints the k-th smallest stored value;
//   otherwise: reads x and y and replaces the value at position x with y.
void solve(){
    cin >> n >> q;
    for(int i = 1; i <= n; ++i){
        cin >> a[i];
        add(a[i], 1);
    }

    int kind, first, second;
    for(int step = 1; step <= q; ++step){
        // Both operation types carry at least one argument after the type.
        cin >> kind >> first;
        if(kind == 1){
            cout << kth(first) << '\n';
            continue;
        }
        cin >> second;
        add(a[first], -1);  // remove the value being overwritten
        a[first] = second;
        add(second, 1);     // insert the replacement
    }
}
// Entry point: unties the C++ streams from C stdio and from each other, then
// runs the solver once.
int main(){
    std::ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    solve();
    return 0;
}