这题之前用指针实现的 Splay 来写，怎么改都超内存（MLE），很难调通。
今天改成用数组（静态节点池）实现 Splay，一次就 AC 了。
单点修改：线段树上所有包含该位置的区间，都在各自的 Splay 中先删除旧值、再插入新值，复杂度 $O(\log^2 n)$。
区间第 k 小查询：二分答案 t，统计区间内不大于 t 的数的个数，找出使该个数大于等于 k 的最小 t，复杂度 $O(\log V \cdot \log^2 n)$，其中 $V$ 为值域大小（通常简记为 $O(\log^3 n)$）。
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define ls c[x][0]
#define rs c[x][1]
using namespace std;
// maxn bounds the array length. maxe bounds the total number of splay nodes over
// all segment-tree nodes: build() inserts every a[i] into each segment-tree node
// whose range contains i, i.e. once per tree level.
// NOTE(review): 15*maxn nodes covers about 17 levels for n up to ~5.2e4;
// confirm against the problem's limits.
const int maxn = 6e4+10;
const int maxe = 15*maxn;
const int inf = 1e9+7;
// Node storage shared by every splay tree:
//   p    - parent
//   c    - left/right children
//   key  - value
//   sz   - total multiplicity in the subtree
//   num  - multiplicity of this node's value
int p[maxe],c[maxe][2];
int key[maxe],sz[maxe],num[maxe];
int a[maxn];
// pool[1..tot2] holds recycled node indices; tot1 is the highest index handed out.
// The pool must be able to hold every node index. Each update deletes before it
// inserts, and an insert of an already-present value reuses nothing, so freed
// indices can pile up far beyond 100 (the previous size overflowed).
int pool[maxe],tot1,tot2;
// True when node r is its parent's right child, false when it is the left child.
bool getlr(int r)
{
    const int parent = p[r];
    return r == c[parent][1];
}
// Makes x child w (0 = left, 1 = right) of node r. If x is a real node (non-zero),
// its parent pointer is set back to r. Returns r so calls can be nested.
int link(int r, int w, int x)
{
    if (x != 0)
        p[x] = r;
    c[r][w] = x;
    return r;
}
// Allocates a splay node and writes its index to r. It takes a recycled index from
// the free pool when one is available, otherwise the next fresh index (++tot1).
// The node starts childless, holding one copy of v, with parent per.
void NewNode(int &r, int v, int per)
{
    const int node = (tot2 != 0) ? pool[tot2--] : ++tot1;
    key[node] = v;
    p[node] = per;
    c[node][0] = 0;
    c[node][1] = 0;
    num[node] = 1;
    sz[node] = 1;
    r = node;
}
// Splay tree used as a multiset of ints. All trees share the global node arrays:
//   p   - parent
//   c   - children
//   key - value
//   num - multiplicity of the value stored at a node
//   sz  - total multiplicity of a subtree
// Equal values share one node. Index 0 is the null node.
// One tree is kept per segment-tree node (the Sp array).
struct Splay
{
    int root;  // root node index; 0 means the tree is empty

    // Makes the tree empty. Its old nodes are not returned to the pool here.
    void init() {root = 0;}

    // Recomputes sz[x] from its children and its own multiplicity.
    void pushup(int x)
    {
        sz[x] = sz[c[x][0]] + sz[c[x][1]] + num[x];
    }

    // Rotates x above its parent. The old parent's size is refreshed here; x's own
    // size is refreshed by the caller (splay). If x ends up without a grandparent,
    // it becomes this tree's root.
    void rot(int x)
    {
        int z = p[p[x]], o = getlr(x);
        // The parent adopts x's inner child, then x adopts the parent on the other
        // side. p[x] still names the old parent until z is linked below.
        link(x,!o,link(p[x],o,c[x][!o]));
        pushup(p[x]);
        if(z) link(z,c[z][1]==p[x],x);
        else p[x]=0, root=x;
    }

    // Splays x upward until its parent is tar; tar == 0 makes x the root.
    void splay(int x,int tar)
    {
        while(p[x]!=tar)
        {
            if(p[p[x]]!=tar)
            {
                if(getlr(x)==getlr(p[x])) rot(p[x]);  // zig-zig: rotate parent first
                else rot(x);                           // zig-zag
            }
            rot(x);
        }
        pushup(x);
    }

    // Looks up value k. If present, splays its node to the root and returns it.
    // Otherwise returns 0 and leaves the tree unchanged.
    int Find(int k)
    {
        int x = root;
        while(x && key[x] != k) x = c[x][key[x] < k];
        if(x) splay(x, 0);
        return x;
    }

    // Adds one copy of k: bumps the multiplicity if k exists, otherwise attaches a
    // new leaf and splays it to the root.
    void Insert(int k)
    {
        if(!root) {NewNode(root, k, 0);return;}
        if(Find(k)) {++num[root]; pushup(root); return;}
        int x = root, x1 = 0;
        while(x) {x1 = x; x = c[x][key[x] < k];}
        NewNode(x, k, 0);
        link(x1, key[x1] < k, x);
        splay(x, 0);
    }

    // Returns the largest node in the root's left subtree (0 if that subtree is empty).
    int prev()
    {
        int x = c[root][0];
        while(c[x][1]) x = c[x][1];
        return x;
    }

    // Returns node x to the global free pool for reuse by NewNode. If the pool is
    // already full, the node is dropped instead of writing past the end of pool[].
    void earse(int x)
    {
        if(!x) return;
        if(tot2 + 1 >= (int)(sizeof(pool) / sizeof(pool[0]))) return;
        pool[++tot2] = x;
    }

    // Removes the root node entirely (all copies of its value) and recycles it.
    void Delroot()
    {
        if(!root) return;
        int x = root, q;
        if(!c[root][0])
        {
            // No left subtree: the right child (possibly 0) becomes the root.
            root = c[root][1];
            p[root] = 0;
            earse(x);  // previously never recycled, so nodes leaked on every such delete
        }
        else
        {
            // Bring the predecessor under the root. It has no right child, so it
            // can adopt the root's right subtree and replace the root.
            q = prev();
            splay(q, root);
            link(q, 1, c[root][1]);
            earse(root);
            root = q; p[q] = 0;
            pushup(root);
        }
    }

    // Removes one copy of x. Does nothing if x is absent.
    void Del(int x)
    {
        if(!Find(x)) return;
        if(num[root] > 1) {--num[root]; pushup(root); return;}
        Delroot();
    }

    // Counts the stored values <= k, with multiplicity. The last node on the search
    // path is splayed to the root afterwards. Without this, a tree degraded into a
    // chain (e.g. by inserting increasing values) is walked in full on every query
    // and never rebalanced.
    int cntLess(int k)
    {
        int res = 0;
        int x = root, last = 0;
        while(x)
        {
            last = x;
            if(key[x] > k)
                x = c[x][0];
            else
            {
                res += num[x] + sz[c[x][0]];
                x = c[x][1];
            }
        }
        if(last) splay(last, 0);
        return res;
    }
}Sp[maxn<<2];
// Returns how many values in positions [L, R] are <= x. It sums the splay-tree
// counts of the segment-tree nodes that tile [L, R]. (l, r, rt) describe the
// segment-tree node currently being visited.
int getnumLess(int L, int R, int x, int l, int r, int rt)
{
    const bool covered = (L <= l && r <= R);
    if (covered)
        return Sp[rt].cntLess(x);

    const int mid = (l + r) / 2;
    int leftCount = 0;
    int rightCount = 0;
    if (L <= mid)
        leftCount = getnumLess(L, R, x, l, mid, rt << 1);
    if (mid < R)
        rightCount = getnumLess(L, R, x, mid + 1, r, rt << 1 | 1);
    return leftCount + rightCount;
}
// Returns the k-th smallest value among positions [L, R] of an array of length N.
// It binary-searches the smallest t in [0, inf] for which at least k values in
// [L, R] are <= t. Returns 0 if no t in that range qualifies.
int Findkth(int L, int R, int N, int k)
{
    int lo = 0;
    int hi = inf;
    int best = 0;
    while (lo <= hi)
    {
        const int mid = (lo + hi) / 2;
        if (getnumLess(L, R, mid, 1, N, 1) < k)
        {
            lo = mid + 1;
            continue;
        }
        best = mid;
        hi = mid - 1;
    }
    return best;
}
void build(int l,int r,int rt)
{
Sp[rt].init();
for(int i=l;i<=r;i++)
Sp[rt].Insert(a[i]);
if(l==r) return;
int m = (l+r)/2;
build(lson);
build(rson);
}
// Point update: position i changes from value x to value y. Walking from node rt
// (covering [l, r]) down to the leaf for i, each visited node's splay tree drops
// one copy of x and gains one copy of y.
void update(int i, int x, int y, int l, int r, int rt)
{
    for (;;)
    {
        Sp[rt].Del(x);
        Sp[rt].Insert(y);
        if (l == r)
            break;
        const int mid = (l + r) / 2;
        if (i <= mid)
        {
            r = mid;
            rt = rt << 1;
        }
        else
        {
            l = mid + 1;
            rt = rt << 1 | 1;
        }
    }
}
// Resets the node allocator before a test case: no indices handed out, free pool empty.
void init()
{
    tot2 = 0;
    tot1 = 0;
}
// Reads the number of test cases. Each case is n values followed by m operations:
//   "Q l r k" - print the k-th smallest value among a[l..r]
//   "C i t"   - set a[i] = t
int main()
{
    int n,m,ca,l,r,k;
    char op[10];
    // Stop cleanly on missing or truncated input. This avoids using an uninitialized
    // count or replaying stale operations.
    if(scanf("%d",&ca) != 1) return 0;
    while(ca--)
    {
        if(scanf("%d%d",&n,&m) != 2) break;
        init();
        for(int i=1;i<=n;i++)
            scanf("%d",&a[i]);
        build(1,n,1);
        while(m--)
        {
            // The width limit keeps an over-long token from overflowing op[10].
            if(scanf("%9s",op) != 1) break;
            if(op[0]=='Q')
            {
                scanf("%d%d%d",&l,&r,&k);
                printf("%d\n",Findkth(l,r,n,k));
            }
            else if(op[0]=='C')
            {
                scanf("%d%d",&l,&k);
                update(l,a[l],k,1,n,1);
                a[l]=k;
            }
        }
    }
    return 0;
}