每种信仰存一个线段树
动态开点线段树就是点 现用现开 所以要存左右儿子
一次询问最多新建logn节点
空间复杂度 m*logn
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=100000+5;
int n,w[N],c[N];
int num,last[N],nxt[2*N],ver[2*N];
inline void add(int x,int y) {nxt[++num]=last[x]; last[x]=num; ver[num]=y;}
int fa[N],sum[N],deep[N],son[N];
void build(int x)
{sum[x]=1;
for(int i=last[x];i;i=nxt[i])
{int y=ver[i];
if(y==fa[x]) return;
fa[y]=x; deep[y]=deep[x]+1;
build(y);
sum[x]+=sum[y];
if(sum[son[x]]<sum[y]) son[x]=y;
}
}
int id,a[N],ord[N],top[N];
void dfs(int x)
{a[++id]=x; ord[x]=id;
if(x==son[fa[x]]) top[x]=top[fa[x]];
else top[x]=x;
if(son[x]) dfs(son[x]);
for(int i=last[x];i;i=nxt[i])
{int y=ver[i];
if(y==fa[x] || y==son[x]) continue;
dfs(y);
}
}
struct point{int ls,rs,maxx,sum;}t[20*N]; int tot,root[N];
void change(int &i,int l,int r,int pos,int x)
{if(!i) i=++tot;
if(l==r) {t[i].sum=t[i].maxx=x; return;}
int mid=(l+r)/2;
if(pos<=mid) change(t[i].ls,l,mid,pos,x);
else change(t[i].rs,mid+1,r,pos,x);
t[i].sum=t[t[i].ls].sum+t[t[i].rs].sum;
t[i].maxx=max(t[t[i].ls].maxx,t[t[i].rs].maxx);
}
point ask(int &i,int l,int r,int p,int q)
{if(!i) return t[tot+1];
if(p<=l && r<=q) return t[i];
int mid=l+r>>1;
if(q<=mid) return ask(t[i].ls,l,mid,p,q);
if(p> mid) return ask(t[i].rs,mid+1,r,p,q);
point r1=ask(t[i].ls,l,mid,p,q),r2=ask(t[i].rs,mid+1,r,p,q);
r1.sum=r1.sum+r2.sum;
r1.maxx=max(r1.maxx,r2.maxx);
return r1;
}
inline point query(int x,int y,int z)
{point re; re.sum=re.maxx=0;
while(top[x]!=top[y])
{if(deep[top[x]]<deep[top[y]]) swap(x,y);
point k=ask(z,1,n,ord[top[x]],ord[x]);
re.sum+=k.sum;
re.maxx=max(re.maxx,k.maxx);
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
point k=ask(z,1,n,ord[y],ord[x]);
re.sum+=k.sum;
re.maxx=max(re.maxx,k.maxx);
return re;
}
int main()
{
int q,x,y;
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++) scanf("%d%d",&w[i],&c[i]);
for(int i=1;i<n;i++) {scanf("%d%d",&x,&y); add(x,y); add(y,x);}
deep[1]=1; build(1); dfs(1);
for(int i=1;i<=n;i++) change(root[c[i]],1,n,ord[i],w[i]);
while(q--)
{ char op[5]; scanf("%s %d%d",op,&x,&y);
if(op[1]=='C')
{ change(root[c[x]],1,n,ord[x],0);
change(root[y],1,n,ord[x],w[x]);
c[x]=y;
}
else if(op[1]=='W')
{ change(root[c[x]],1,n,ord[x],y);
w[x]=y;
}
else if(op[1]=='S')
{point ans=query(x,y,root[c[y]]);
printf("%d\n",ans.sum);
}
else if(op[1]=='M')
{point ans=query(x,y,root[c[y]]);
printf("%d\n",ans.maxx);
}
}
return 0;
}