题面:
题解:
树链剖分+动态开点线段树。
对于每一个宗教值 ci 建立一棵线段树,查询x–y时,在c [ x ] 这棵线段树上查询。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#define ll long long
#define llu unsigned ll
#define pr make_pair
#define pb push_back
#define y1 yy
using namespace std;
const int maxn=100100;
int head[maxn],ver[maxn<<1],nt[maxn<<1];
int f[maxn],d[maxn],si[maxn],son[maxn],rk[maxn];
int top[maxn],id[maxn],w[maxn],c[maxn];
int tot=1,cnt=0,tcnt=0,n,q;
char str[16];
void add(int x,int y)
{
ver[++tot]=y,nt[tot]=head[x],head[x]=tot;
}
void dfs1(int x,int fa)
{
int max_son=0;
si[x]=1;
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y==fa) continue;
f[y]=x;
d[y]=d[x]+1;
dfs1(y,x);
si[x]+=si[y];
if(si[y]>max_son) max_son=si[y],son[x]=y;
}
}
void dfs2(int x,int t)
{
top[x]=t;
id[x]=++cnt;
rk[cnt]=x;
if(!son[x]) return ;
dfs2(son[x],t);
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y!=f[x]&&y!=son[x])
dfs2(y,y);
}
}
struct node
{
int lc,rc;
int maxx,sum;
}t[maxn*20*3];
int root[maxn];
int newnode(void)
{
++tcnt;
t[tcnt].lc=t[tcnt].rc=t[tcnt].maxx=t[tcnt].sum=0;
return tcnt;
}
int _insert(int p,int l,int r,int pos,int val)
{
if(!p) p=newnode();
if(l==r)
{
t[p].maxx=val;
t[p].sum=val;
return p;
}
int mid=(l+r)>>1;
if(pos<=mid) t[p].lc=_insert(t[p].lc,l,mid,pos,val);
else t[p].rc=_insert(t[p].rc,mid+1,r,pos,val);
t[p].maxx=max(t[t[p].lc].maxx,t[t[p].rc].maxx);
t[p].sum=t[t[p].lc].sum+t[t[p].rc].sum;
return p;
}
int ask_max(int p,int L,int R,int l,int r)
{
if(L>=l&&R<=r) return t[p].maxx;
int maxx=0;
int mid=(L+R)>>1;
if(l<=mid) maxx=max(maxx,ask_max(t[p].lc,L,mid,l,r));
if(r>mid) maxx=max(maxx,ask_max(t[p].rc,mid+1,R,l,r));
return maxx;
}
int ask_sum(int p,int L,int R,int l,int r)
{
if(L>=l&&R<=r) return t[p].sum;
int ans=0;
int mid=(L+R)>>1;
if(l<=mid) ans=ans+ask_sum(t[p].lc,L,mid,l,r);
if(r>mid) ans=ans+ask_sum(t[p].rc,mid+1,R,l,r);
return ans;
}
int ask_road_max(int x,int y,int ci)
{
int maxx=0;
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]])
swap(x,y);
maxx=max(maxx,ask_max(root[ci],1,cnt,id[top[x]],id[x]));
x=f[top[x]];
}
if(id[x]>id[y])
swap(x,y);
maxx=max(maxx,ask_max(root[ci],1,cnt,id[x],id[y]));
return maxx;
}
int ask_road_sum(int x,int y,int ci)
{
int ans=0;
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]])
swap(x,y);
ans+=ask_sum(root[ci],1,cnt,id[top[x]],id[x]);
x=f[top[x]];
}
if(id[x]>id[y])
swap(x,y);
ans+=ask_sum(root[ci],1,cnt,id[x],id[y]);
return ans;
}
int main(void)
{
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++)
scanf("%d%d",&w[i],&c[i]);
int x,y;
for(int i=1;i<n;i++)
scanf("%d%d",&x,&y),add(x,y),add(y,x);
dfs1(1,0);
dfs2(1,1);
for(int i=1;i<=n;i++)
root[c[i]]=_insert(root[c[i]],1,cnt,id[i],w[i]);
for(int i=1;i<=q;i++)
{
scanf("%s",str);
if(str[1]=='C')
{
scanf("%d%d",&x,&y);
root[c[x]]=_insert(root[c[x]],1,cnt,id[x],0);
c[x]=y;
root[c[x]]=_insert(root[c[x]],1,cnt,id[x],w[x]);
}
else if(str[1]=='W')
{
scanf("%d%d",&x,&y);
w[x]=y;
root[c[x]]=_insert(root[c[x]],1,cnt,id[x],w[x]);
}
else if(str[1]=='S')
{
scanf("%d%d",&x,&y);
printf("%d\n",ask_road_sum(x,y,c[x]));
}
else
{
scanf("%d%d",&x,&y);
printf("%d\n",ask_road_max(x,y,c[x]));
}
}
return 0;
}