对每种宗教各建立一棵动态开点线段树(写法与主席树类似):初始时不建任何节点,只在需要修改某个位置时才沿根到叶的路径新建节点。每次单点修改最多新建 $O(\log n)$ 个节点,初始插入 $n$ 次、之后最多再修改 $O(q)$ 次,所以空间复杂度是 $O((n+q)\log n)$。
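PS:为什么第二份会 T(TLE)?原因在于 `mx` 数组的大小:线段树节点编号由 `cnt` 分配,最大可达约 $M$,所以 `ls`、`rs`、`sum`、`mx` 都必须开到 `M`。若把 `mx` 写成 `mx[N]`(只有约 $10^5$),`update` 中对 `mx[k]` 的写入就会越界,很可能改写相邻的全局数组(如 `dep`、`f`、`bl`),使树链剖分向上跳的循环无法结束,从而超时。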
#include<iostream>
#include<cstdio>
#include<cstring>
#define M 6000005
#define inf (1<<30)
using namespace std;
inline int read()
{
    // Read one decimal integer from stdin. Every non-digit before the number is skipped,
    // and a '-' among those skipped characters makes the result negative. The first
    // non-digit after the number is consumed as well.
    int sign = 1;
    int ch = getchar();
    for (; ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sign = -1;
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + ch - '0';
    return value * sign;
}
// n: number of tree nodes, m: number of operations, cnt: adjacency-list arc counter.
// place: NOTE(review): not used anywhere in the visible code.
// size: one counter shared by two phases - dfs2 uses it to hand out DFS-order positions
//   (ending at n), then change() keeps incrementing it to allocate segment-tree node ids.
//   NOTE(review): with `using namespace std;` under C++17, an unqualified use of `size`
//   can be ambiguous with std::size; refer to it as `::size` where it is used.
int n,m,cnt,place,size;
// w[i]: value of node i; c[i]: religion of node i;
// root[r]: root node id of religion r's dynamic segment tree (0 = no nodes yet).
int w[100005],c[100005],root[100005];
// fa[u][0]: parent of u (only column 0 is used in the visible code); deep[u]: depth;
// pl[u]: heavy-light position; belong[u]: top node of u's heavy chain;
// son[u]: size of u's subtree (filled by dfs1, despite the name).
int fa[100005][17],deep[100005],pl[100005],belong[100005],son[100005];
// Segment-tree node pool: children, max and sum per node id. Id 0 is never allocated,
// so it acts as an empty node whose max and sum stay 0.
int ls[M],rs[M],mx[M],sum[M];
// Adjacency list: each undirected edge is stored as two arcs, hence 2*(n-1) slots.
struct data{int to,next;}e[200005];int head[100005];
void ins(int u,int v)
{
    // Add the undirected edge {u, v} as two directed arcs, u->v first and then v->u,
    // each pushed onto the front of its tail's list (arc ids start at 1).
    const int ends[2][2] = {{u, v}, {v, u}};
    for (int t = 0; t < 2; ++t) {
        const int from = ends[t][0];
        const int to = ends[t][1];
        ++cnt;
        e[cnt].to = to;
        e[cnt].next = head[from];
        head[from] = cnt;
    }
}
// Short aliases for the heavy-light helpers below:
//   bl -> belong (chain top), dep -> deep (depth), pos -> pl (HLD position), f -> fa (parent table).
// NOTE: these are object-like macros, so every later token `bl`, `dep`, `pos` or `f` is
// rewritten, including unrelated local names. Avoid those identifiers after this point.
#define bl belong
#define dep deep
#define pos pl
#define f fa
void dfs1(int u)
{
son[u]=1;
for(int i=head[u];i;i=e[i].next) {
int id=e[i].to;
if(id==f[u][0])continue;
dep[id]=dep[u]+1;f[id][0]=u;
dfs1(id);
son[u]+=son[id];
}
}
void dfs2(int u,int chain)
{
pos[u]=++size;bl[u]=chain;int k=0;
for(int i=head[u];i;i=e[i].next) {
int id=e[i].to;
if(dep[id]>dep[u]&&son[id]>son[k]) k=id;
}if(k==0)return ;dfs2(k,chain);
for(int i=head[u];i;i=e[i].next) {
int id=e[i].to;
if(dep[id]>dep[u]&&id!=k) dfs2(id,id);
}
}
void update(int x)
{
mx[x]=max(mx[ls[x]],mx[rs[x]]);
sum[x]=sum[ls[x]]+sum[rs[x]];
}
void change(int &k,int l,int r,int x,int num)
{
if(!k)k=++size;
if(l==r){mx[k]=sum[k]=num;return;}
int mid=(l+r)>>1;
if(x<=mid)change(ls[k],l,mid,x,num);
else change(rs[k],mid+1,r,x,num);
update(k);
}
int askmx(int k,int l,int r,int x,int y)
{
if(!k)return 0;
if(l==x&&y==r)return mx[k];
int mid=(l+r)>>1;
if(y<=mid)return askmx(ls[k],l,mid,x,y);
else if(x>mid)return askmx(rs[k],mid+1,r,x,y);
else return max(askmx(ls[k],l,mid,x,mid),askmx(rs[k],mid+1,r,mid+1,y));
}
int asksum(int k,int l,int r,int x,int y)
{
if(!k)return 0;
if(l==x&&y==r)return sum[k];
int mid=(l+r)>>1;
if(y<=mid)return asksum(ls[k],l,mid,x,y);
else if(x>mid)return asksum(rs[k],mid+1,r,x,y);
else return asksum(ls[k],l,mid,x,mid)+asksum(rs[k],mid+1,r,mid+1,y);
}
int solvesum(int x,int y,int c)
{
int sum=0;
while(bl[x]!=bl[y]) {
if(dep[bl[x]]<dep[bl[y]])swap(x,y);
sum+=asksum(root[c],1,n,pos[bl[x]],pos[x]);
x=f[bl[x]][0];
}if(dep[x]>dep[y])swap(x,y);
sum+=asksum(root[c],1,n,pos[x],pos[y]);
return sum;
}
int solvemx(int x,int y,int c)
{
int mx=-inf;
while(bl[x]!=bl[y]) {
if(dep[bl[x]]<dep[bl[y]])swap(x,y);
mx=max(mx,askmx(root[c],1,n,pos[bl[x]],pos[x]));
x=f[bl[x]][0];
}if(dep[x]>dep[y])swap(x,y);
mx=max(mx,askmx(root[c],1,n,pos[x],pos[y]));
return mx;
}
// Driver. It first writes every node's value into its religion's tree, then runs the m
// operations:
//   CC x y : move node x to religion y (zero its slot in the old tree, write w[x] into the new one)
//   CW x y : set node x's value to y
//   QS x y : print the sum over path x..y, counting only nodes of x's religion
//   QM x y : print the max over path x..y, counting only nodes of x's religion
// Any op whose first letter is not 'C' is a query: second letter 'S' gives the sum, anything
// else gives the max.
void solve()
{
    for(int i=1;i<=n;i++)change(root[c[i]],1,n,pl[i],w[i]);
    for(int i=1;i<=m;i++){
        char ch[5];
        // Width 4 leaves room for the terminator, so an over-long token cannot overflow ch.
        scanf("%4s",ch);
        int x=read(),y=read();
        if(ch[0]=='C'){
            if(ch[1]=='C'){
                change(root[c[x]],1,n,pl[x],0);    // clear x in its old religion's tree
                c[x]=y;
                change(root[c[x]],1,n,pl[x],w[x]); // insert x into the new religion's tree
            }else {
                change(root[c[x]],1,n,pl[x],y);
                w[x]=y;
            }
        }else{
            if(ch[1]=='S')printf("%d\n",solvesum(x,y,c[x]));
            else printf("%d\n",solvemx(x,y,c[x]));
        }
    }
}
int main()
{
    // Input: n and m, then (value, religion) for each node, then n-1 tree edges.
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i) {
        w[i] = read();
        c[i] = read();
    }
    for (int remaining = n - 1; remaining > 0; --remaining) {
        const int u = read();
        const int v = read();
        ins(u, v);
    }
    // Heavy-light decomposition rooted at node 1, then answer the operations.
    dfs1(1);
    dfs2(1, 1);
    solve();
    return 0;
}
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fod(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
// N: capacity of the per-vertex arrays (presumably n <= 1e5 - TODO confirm against the limits).
// M: capacity of the dynamic segment-tree node pool; every per-node array must be sized M.
// inf: 2^30, used as the -inf starting value in solvemx.
const int N=1e5+10,M=6e6+10,inf=(1<<30);
inline int read()
{
    // Parse the next decimal integer on stdin. Non-digits are skipped while remembering
    // whether any of them was '-', then digits are consumed. The character that ends the
    // number is consumed too.
    bool negative = false;
    int ch;
    while ((ch = getchar()) < '0' || ch > '9')
        if (ch == '-') negative = true;
    int x = 0;
    do {
        x = x * 10 + ch - '0';
        ch = getchar();
    } while (ch >= '0' && ch <= '9');
    return negative ? -x : x;
}
// Global state (all zero-initialised):
//   root[r]        : root node id of religion r's dynamic segment tree (0 = empty tree)
//   ls/rs/sum/mx   : segment-tree node pool indexed by node id; id 0 is the empty node
//   dep/s/f/pos/bl : heavy-light data - depth, subtree size, parent, position, chain top
//   last           : first arc of each vertex's adjacency list (0 = none)
//   c/w            : religion and value of each vertex
//   size           : DFS-order counter used by dfs2
//   len / cnt      : number of arcs / number of allocated segment-tree nodes
// mx needs M slots like ls/rs/sum, because update() writes mx[k] for every allocated node
// id k, and k grows far beyond N. With mx[N] those writes run past the array and can
// corrupt neighbouring globals (e.g. the heavy-light arrays), which can make the
// chain-climbing loops in solvesum/solvemx never terminate.
int root[N],ls[M],rs[M],sum[M],mx[M],
dep[N],s[N],f[N],pos[N],bl[N],last[N],c[N],w[N],
size=0,len=0,cnt=0,n,q;
// One directed arc of the adjacency list: `to` is its head vertex, and `next` is the arc
// previously first in the same tail's list (0 ends the list).
struct Edge
{
    int to, next;
    Edge(int to = 0, int next = 0) : to(to), next(next) {}
} e[N << 2];
// Push the directed arc u -> v onto the front of u's list (arc ids start at 1).
void add_edge(int u, int v)
{
    ++len;
    e[len].to = v;
    e[len].next = last[u];
    last[u] = len;
}
void dfs1(int u)
{
    // First heavy-light pass over the subtree of u: parent f[], depth dep[] and
    // subtree size s[]. The root's parent stays 0 (zero-initialised global).
    s[u] = 1;
    for (int arc = last[u]; arc != 0; arc = e[arc].next) {
        const int v = e[arc].to;
        if (v != f[u]) {
            dep[v] = dep[u] + 1;
            f[v] = u;
            dfs1(v);
            s[u] += s[v];
        }
    }
}
void dfs2(int u,int chain)
{
pos[u]=++size;bl[u]=chain;int k=0;
for(int i=last[u];i;i=e[i].next) {
int id=e[i].to;
if(dep[id]>dep[u]&&s[id]>s[k]) k=id;
}if(k==0)return ;dfs2(k,chain);
for(int i=last[u];i;i=e[i].next) {
int id=e[i].to;
if(dep[id]>dep[u]&&id!=k) dfs2(id,id);
}
}
void update(int &k,int l,int r,int pos,int val)
{
    // Point-assign val at index pos of the dynamic segment tree rooted at k, which covers
    // [l, r]. k is a reference, so an empty slot (0) receives a fresh node id from cnt
    // on the way down.
    if (k == 0)
        k = ++cnt;
    if (l == r) {
        mx[k] = val;
        sum[k] = val;
        return;
    }
    const int mid = (l + r) >> 1;
    if (pos > mid)
        update(rs[k], mid + 1, r, pos, val);
    else
        update(ls[k], l, mid, pos, val);
    // Pull-up from the (possibly empty, id 0) children.
    sum[k] = sum[ls[k]] + sum[rs[k]];
    mx[k] = max(mx[ls[k]], mx[rs[k]]);
}
int query_sum(int k,int l,int r,int ql,int qr)
{
    // Sum over [ql, qr] (a sub-range of [l, r]) within the subtree of node k.
    // An unallocated node (id 0) contributes 0.
    if (!k)
        return 0;
    if (l == ql && r == qr)
        return sum[k];
    const int mid = (l + r) >> 1;
    int acc = 0;
    // Visit each half the query touches, clipped to that half.
    if (ql <= mid)
        acc += query_sum(ls[k], l, mid, ql, qr < mid ? qr : mid);
    if (qr > mid)
        acc += query_sum(rs[k], mid + 1, r, ql > mid ? ql : mid + 1, qr);
    return acc;
}
int query_mx(int k,int l,int r,int ql,int qr)
{
    // Maximum over [ql, qr] (a sub-range of [l, r]) within the subtree of node k.
    // An unallocated node (id 0) contributes 0.
    if (!k)
        return 0;
    if (l == ql && r == qr)
        return mx[k];
    const int mid = (l + r) >> 1;
    const bool hitsLeft = ql <= mid;
    const bool hitsRight = qr > mid;
    if (hitsLeft && hitsRight)
        return max(query_mx(ls[k], l, mid, ql, mid),
                   query_mx(rs[k], mid + 1, r, mid + 1, qr));
    return hitsLeft ? query_mx(ls[k], l, mid, ql, qr)
                    : query_mx(rs[k], mid + 1, r, ql, qr);
}
int solvesum(int x,int y,int c)
{
    // Sum of the values stored in religion c's tree along the tree path x..y.
    int total = 0;
    while (bl[x] != bl[y]) {
        // Lift the endpoint whose chain top is deeper, adding its chain segment.
        if (dep[bl[y]] > dep[bl[x]]) {
            const int t = x;
            x = y;
            y = t;
        }
        total += query_sum(root[c], 1, n, pos[bl[x]], pos[x]);
        x = f[bl[x]];
    }
    // Same chain: the shallower endpoint has the smaller position.
    if (dep[y] < dep[x]) {
        const int t = x;
        x = y;
        y = t;
    }
    return total + query_sum(root[c], 1, n, pos[x], pos[y]);
}
int solvemx(int x,int y,int c)
{
    // Maximum value stored in religion c's tree along the tree path x..y.
    // It starts from -inf and folds in every chain segment's max.
    int best = -inf;
    for (; bl[x] != bl[y]; x = f[bl[x]]) {
        // Make x the endpoint whose chain top is deeper, then take its chain segment.
        if (dep[bl[y]] > dep[bl[x]]) {
            const int t = x;
            x = y;
            y = t;
        }
        best = max(best, query_mx(root[c], 1, n, pos[bl[x]], pos[x]));
    }
    const int hi = dep[x] > dep[y] ? y : x;   // shallower end (smaller position)
    const int lo = hi == x ? y : x;           // deeper end
    return max(best, query_mx(root[c], 1, n, pos[hi], pos[lo]));
}
// Driver. It first writes every vertex's value into its religion's tree, then runs the q
// operations:
//   CC x y : move vertex x to religion y (zero its slot in the old tree, write w[x] into the new one)
//   CW x y : set vertex x's value to y
//   QS x y : print the sum over path x..y, counting only vertices of x's religion
//   QM x y : print the max over path x..y, counting only vertices of x's religion
// Any op whose first letter is not 'C' is a query: second letter 'S' gives the sum, anything
// else gives the max.
void solve()
{
    fo(i,1,n) update(root[c[i]],1,n,pos[i],w[i]);
    for(int i=1;i<=q;i++) {
        char ch[7];int x,y;
        // Width 6 leaves room for the terminator, so an over-long token cannot overflow ch.
        scanf("%6s",ch);x=read();y=read();
        if(ch[0]=='C') {
            if(ch[1]=='C') {
                update(root[c[x]],1,n,pos[x],0);    // clear x in its old religion's tree
                c[x]=y;
                update(root[c[x]],1,n,pos[x],w[x]); // insert x into the new religion's tree
            }else {
                update(root[c[x]],1,n,pos[x],y);
                w[x]=y;
            }
        }else {
            if(ch[1]=='S')printf("%d\n",solvesum(x,y,c[x]));
            else printf("%d\n",solvemx(x,y,c[x]));
        }
    }
}
int main()
{
scanf("%d%d",&n,&q);
fo(i,1,n) {w[i]=read();c[i]=read();}
for(int i=1,from,to;i<n;i++) {
from=read();to=read();
add_edge(from,to);add_edge(to,from);
}
dfs1(1);dfs2(1,1);
solve();
return 0;
}