我还是太弱了,第一次碰到带修改的主席树。首先像 COT 那样建主席树:每个点的版本是在父亲版本的基础上插入该点的值,因此记录的是从根到该点路径上的前缀和。修改操作只在树状数组上进行:利用 dfs 序,把对一个点的修改变成对它子树区间 [lft, rht] 的修改;树状数组每个位置存一棵线段树,只记录在原值基础上加或减了多少。查询时开两个数组,把要相加和要相减的所有线段树根都记下来,然后在二分时让它们一起往下走就可以了。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
// Capacity for tree nodes and for operations: both n and m index arrays of size N.
#define N 100010
// Capacity of the shared segment-tree node pool (sum/ls/rs).
#define M 8000010
// Per tree node (indexed 1..n):
//   w        node value; replaced by its compressed index before dfs runs
//   vis      Tarjan LCA: set once dfs has finished the node's subtree
//   lft/rht  the node's subtree occupies dfs positions [lft, rht]
//   first    head of the node's adjacency list in e[]
//   fa       union-find parent used by Tarjan LCA
//   f        parent in the rooted tree (0 for the root)
//   root     persistent segment tree of the values on the root-to-node path
int w[N],vis[N],lft[N],rht[N],first[N],fa[N],f[N],root[N];
// Segment-tree roots gathered by getbit() for one query: a[] is added, b[] is subtracted.
int a[N],b[N];
// Fenwick tree over dfs positions; c[i] is the root of a segment tree of value deltas.
int c[N];
// Sorted distinct values for coordinate compression. It receives the n initial
// values plus one value per modification (up to m more), so it needs room for
// roughly 2 * N entries; a size of N could be overrun.
int san[N << 1];
// Operation i: A[i] == 0 sets node B[i]'s value to C[i]; otherwise it asks for
// the A[i]-th largest value on the path B[i]..C[i], whose LCA is stored in lca[i].
int A[N],B[N],C[N],lca[N];
// Shared segment-tree node pool; index 0 acts as the empty tree.
int sum[M],ls[M],rs[M];
// n: nodes, m: operations, tot: number of (distinct, once compressed) values,
// cnt: arcs used in e[], ne: unused here, siz: pool nodes used, tim: dfs clock,
// numa/numb: number of roots currently stored in a[]/b[].
int n,m,tot,cnt,ne,siz,tim,numa,numb;
// ask[x]: (other endpoint, operation index) pairs consumed by Tarjan LCA.
vector<pair<int,int> > ask[N];
// Arc pool for the adjacency lists: two arcs per undirected edge.
struct edge{int y,next;}e[N << 1];
// Stores the undirected edge x-y as two directed arcs (x->y, then y->x),
// each pushed onto the front of its source's adjacency list in first[]/e[].
void addedge(int x,int y){
    const int from[2] = {x, y};
    const int to[2] = {y, x};
    for (int k = 0; k < 2; ++k) {
        ++cnt;
        e[cnt].y = to[k];
        e[cnt].next = first[from[k]];
        first[from[k]] = cnt;
    }
}
char ch;
// Reads the next unsigned decimal integer from stdin into x, skipping every
// non-digit character before it. The character that ends the number is
// consumed and left in the global ch.
void read(int &x){
    do ch = getchar(); while (ch < '0' || ch > '9');
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + ch - '0';
    x = value;
}
// Persistent point update: builds in rt a new version of the tree rooted at x
// in which position y (within [l, r]) has its count increased by k. Only the
// nodes on the root-to-leaf path are freshly allocated from the pool; all
// other subtrees are shared with x. Callers pass rt == 0.
void update(int x,int &rt,int l,int r,int y,int k)
{
    const int mid = (l + r) >> 1;
    const bool goLeft = (y <= mid);
    if (!rt)
    {
        rt = ++siz;
        sum[rt] = sum[x];
        // Share the untouched child with x; the child on y's side is cleared
        // so the recursive call below allocates a fresh node for it.
        ls[rt] = goLeft ? 0 : ls[x];
        rs[rt] = goLeft ? rs[x] : 0;
    }
    sum[rt] += k;
    if (l == r) return;
    if (goLeft) update(ls[x], ls[rt], l, mid, y, k);
    else        update(rs[x], rs[rt], mid + 1, r, y, k);
}
// Union-find lookup used by Tarjan LCA: returns the representative of x and
// re-points every node on the walked path directly at it (path compression).
int find(int x)
{
    int top = x;
    while (fa[top] != top) top = fa[top];
    while (fa[x] != top)
    {
        int up = fa[x];
        fa[x] = top;
        x = up;
    }
    return top;
}
void dfs(int x)
{
update(root[f[x]],root[x],1,tot,w[x],1);
tim++;
lft[x]=tim;fa[x]=x;
for(int i = first[x];i;i = e[i].next)
if(e[i].y != f[x])
{
f[e[i].y] = x;
dfs(e[i].y);
fa[e[i].y] = x;
}
rht[x] = tim;
vis[x] = 1;
int t = ask[x].size();
for (int i = 0;i < t;i ++)
if (vis[ask[x][i].first])
{
lca[ask[x][i].second] = find(ask[x][i].first);
}
}
// Returns the lowest set bit of x (the Fenwick-tree step size); expects x > 0.
int lowbit(int x)
{
    const int negated = ~x + 1;  // two's-complement negation of x
    return x & negated;
}
// Collects every segment-tree root whose counts make up node x's current
// root-to-node multiset: the static prefix tree root[x] followed by the
// Fenwick roots c[] covering dfs prefix lft[x]. The roots are appended to
// a[] (counted positively) when p != 0, or to b[] (counted negatively) when p == 0.
void getbit(int x,int p)
{
    int *buf = p ? a : b;
    int &len = p ? numa : numb;
    buf[++len] = root[x];
    for (int i = lft[x]; i; i -= lowbit(i))
        buf[++len] = c[i];
}
void addup(int x,int pos,int val)
{
while(x<=n)
{
int tmp=0;
update(c[x],tmp,1,tot,pos,val);
c[x]=tmp;
x+=lowbit(x);
}
}
int query(int k)
{
int t,tt,l = 1,r = tot;
while (l < r)
{
t = 0,tt = 0;
for (int i = 1;i <= numa;i ++) t += sum[rs[a[i]]],tt += sum[a[i]];
for (int i = 1;i <= numb;i ++) t -= sum[rs[b[i]]],tt -= sum[b[i]];
if (tt < k) return -1;
if (k <= t)
{
for (int i = 1;i <= numa;i ++) a[i] = rs[a[i]];
for (int i = 1;i <= numb;i ++) b[i] = rs[b[i]];
l = (l + r >> 1) + 1;
}
else
{
for (int i = 1;i <= numa;i ++) a[i] = ls[a[i]];
for (int i = 1;i <= numb;i ++) b[i] = ls[b[i]];
r = l + r >> 1;
k -= t;
}
}
return l;
}
// Reads the tree and the operation list, compresses all values that can ever
// appear, runs the DFS (prefix trees, dfs order, offline LCA), then replays the
// operations in input order: A[i] != 0 prints the A[i]-th largest value on the
// path B[i]..C[i] (or "invalid request!"), and A[i] == 0 sets node B[i] to C[i].
int main()
{
    read(n);
    read(m);
    for (int i = 1; i <= n; i++)
    {
        read(w[i]);
        san[++tot] = w[i];
    }
    for (int i = 1; i < n; i++)
    {
        int u, v;
        read(u);
        read(v);
        addedge(u, v);
    }
    for (int i = 1; i <= m; i++)
    {
        read(A[i]);
        read(B[i]);
        read(C[i]);
        if (A[i] == 0)
        {
            // A future value must be part of the compressed domain too.
            san[++tot] = C[i];
        }
        else
        {
            // Register the query at both endpoints for Tarjan LCA.
            ask[B[i]].push_back(make_pair(C[i], i));
            ask[C[i]].push_back(make_pair(B[i], i));
        }
    }

    // Coordinate compression of every value that can occur.
    sort(san + 1, san + 1 + tot);
    tot = unique(san + 1, san + 1 + tot) - san - 1;
    for (int i = 1; i <= n; i++)
        w[i] = lower_bound(san + 1, san + tot + 1, w[i]) - san;

    dfs((n + 1) / 2);

    for (int i = 1; i <= m; i++)
    {
        if (A[i] != 0)
        {
            // path(u, v) = pre(u) + pre(v) - pre(lca) - pre(parent(lca))
            const int anc = lca[i];
            numa = numb = 0;
            getbit(B[i], 1);
            getbit(C[i], 1);
            getbit(anc, 0);
            getbit(f[anc], 0);
            const int idx = query(A[i]);
            if (idx < 0) puts("invalid request!");
            else printf("%d\n", san[idx]);
        }
        else
        {
            // Replace node's value on its whole subtree interval [lft, rht].
            const int node = B[i];
            const int val = lower_bound(san + 1, san + tot + 1, C[i]) - san;
            addup(lft[node], w[node], -1);
            addup(rht[node] + 1, w[node], 1);
            addup(lft[node], val, 1);
            addup(rht[node] + 1, val, -1);
            w[node] = val;
        }
    }
    return 0;
}