传送门
树剖+线段树套平衡树
看到如此长的时限,我就放心大胆的开始(开车)了。
打代码真爽
具体实现:
1.二分答案,
2.在被轻重链剖分的树上跑线段树,
3.在每一个节点上维护平衡树用来查询区间排名。
复杂度分析:
1.二分一只log
2.树剖一只log
3.线段树一只log
4.平衡树一只log
然后就实现了Nlog^4N的做法。
四只log,四只log,跑得快,跑得快。
其实和N^2是没有任何区别的。
代码贼长。
#include<cmath>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define N 100005
#define M 5000005
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
for (;ch<48||ch>57;ch=getchar()) if (ch=='-') f=-1;
for (;ch>47&&ch<58;ch=getchar()) x=x*10-48+ch;
return x;
}
int n,q,cnt,tot,sz,tmp,x,y;
int bin[20],T[N],hash[N],f[N],a[N],b[N];
int fa[N][17],son[N],dep[N],top[N],rt[N*3];
int w[M],v[M],s[M],rnd[M],ls[M],rs[M];
int vis[N],head[N],pl[N];
struct edge{int to,next;}e[N*2];
void add(int x,int y){
e[++tot]=(edge){y,head[x]};
head[x]=tot;
e[++tot]=(edge){x,head[y]};
head[y]=tot;
}
int find(int x){
int l=1,r=tot;
while (l<=r){
int mid=(l+r)/2;
if (hash[mid]==x) return mid;
if (hash[mid]<x) l=mid+1; else r=mid-1;
}
return l;
}
void dfs1(int x){
son[x]=vis[x]=1;
for (int i=1;bin[i]<=dep[x]&&i<=16;i++)
fa[x][i]=fa[fa[x][i-1]][i-1];
for (int i=head[x];i;i=e[i].next)
if (!vis[e[i].to]){
dep[e[i].to]=dep[x]+1;
fa[e[i].to][0]=x;
dfs1(e[i].to);
son[x]+=son[e[i].to];
}
}
void dfs2(int x,int tp){
pl[x]=++cnt; top[x]=tp;
int k=0;
for (int i=head[x];i;i=e[i].next)
if (dep[e[i].to]>dep[x]&&son[e[i].to]>son[k])
k=e[i].to;
if (k) dfs2(k,tp);
for (int i=head[x];i;i=e[i].next)
if (dep[e[i].to]>dep[x]&&e[i].to!=k)
dfs2(e[i].to,e[i].to);
}
int lca(int x,int y){
if (dep[x]<dep[y]) swap(x,y);
int tmp=dep[x]-dep[y];
for (int i=0;i<=16;i++)
if (bin[i]&tmp) x=fa[x][i];
for (int i=16;i>=0;i--)
if (fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return (x==y)?x:fa[x][0];
}
void update(int k){
s[k]=s[ls[k]]+s[rs[k]]+w[k];
}
void rturn(int &k){
int t=ls[k]; ls[k]=rs[t]; rs[t]=k;
update(k); update(t); k=t;
}
void lturn(int &k){
int t=rs[k]; rs[k]=ls[t]; ls[t]=k;
update(k); update(t); k=t;
}
void insert(int &k,int num){
if (!k){
k=++sz; rnd[k]=rand();
w[k]=s[k]=1; v[k]=num;
return;
}
s[k]++;
if (num==v[k]){
w[k]++; return;
}
if (num<v[k]){
insert(ls[k],num);
if (rnd[ls[k]]<rnd[k]) rturn(k);
}
else{
insert(rs[k],num);
if (rnd[rs[k]]<rnd[k]) lturn(k);
}
}
void del(int &k,int num){
if (!k) return;
if (num==v[k]){
if (w[k]>1){
w[k]--; s[k]--;
return;
}
if (ls[k]*rs[k]==0) k=ls[k]+rs[k];
else if (rnd[ls[k]]<rnd[rs[k]]){
rturn(k); del(k,num);
}
else{
lturn(k); del(k,num);
}
}
else if (num<v[k]){
del(ls[k],num); s[k]--;
}
else{
del(rs[k],num); s[k]--;
}
}
void askrk(int k,int num){
if (!k) return;
if (num==v[k]){tmp+=s[rs[k]]; return;}
if (num<v[k]){
tmp+=s[rs[k]]+w[k];
askrk(ls[k],num);
}
else askrk(rs[k],num);
}
void change(int k,int l,int r,int pos,int x,int y){
del(rt[k],x); insert(rt[k],y);
if (l==r) return;
int mid=(l+r)/2;
if (pos<=mid) change(k*2,l,mid,pos,x,y);
else change(k*2+1,mid+1,r,pos,x,y);
}
void ask(int k,int l,int r,int x,int y,int num){
if (x==l&&y==r){askrk(rt[k],num); return;}
int mid=(l+r)/2;
if (y<=mid) ask(k*2,l,mid,x,y,num);
else if (x>mid) ask(k*2+1,mid+1,r,x,y,num);
else ask(k*2,l,mid,x,mid,num),ask(k*2+1,mid+1,r,mid+1,y,num);
}
void getrk(int x,int f,int num){
while (top[x]!=top[f]){
ask(1,1,n,pl[top[x]],pl[x],num);
x=fa[top[x]][0];
}
ask(1,1,n,pl[f],pl[x],num);
}
void solve(int x,int y,int rk){
int t=lca(x,y),ans=-1;
tmp=0; getrk(y,t,0); getrk(x,t,0);
if (tmp-1<rk){
printf("invalid request!\n");
return;
}
int l=1,r=tot;
while (l<=r){
int mid=(l+r)/2;
tmp=0;
getrk(x,t,mid);
getrk(y,t,mid);
if (T[t]>mid) tmp--;
if (tmp<rk) r=mid-1,ans=mid;
else l=mid+1;
}
printf("%d\n",hash[ans]);
}
int main(){
bin[0]=1;
for (int i=1;i<=16;i++) bin[i]=bin[i-1]*2;
n=read(); q=read();
for (int i=1;i<=n;i++)
hash[i]=T[i]=read();
for (int i=1;i<n;i++){
x=read(); y=read();
add(x,y);
}
tot=n;
dfs1(1); dfs2(1,1);
for (int i=1;i<=q;i++){
scanf("%d%d%d",&f[i],&a[i],&b[i]);
if (!f[i]) hash[++tot]=b[i];
}
sort(hash+1,hash+tot+1);
int top=1;
for (int i=2;i<=tot;i++)
if (hash[i]!=hash[i-1]) hash[++top]=hash[i];
tot=top;
for (int i=1;i<=n;i++) T[i]=find(T[i]);
for (int i=1;i<=q;i++) if (!f[i]) b[i]=find(b[i]);
for (int i=1;i<=n;i++) change(1,1,n,pl[i],0,T[i]);
for (int i=1;i<=q;i++)
if (!f[i]){
change(1,1,n,pl[a[i]],T[a[i]],b[i]);
T[a[i]]=b[i];
}
else solve(a[i],b[i],f[i]);
}