处理点的树链剖分。支持单点更新,树上求最值和树上求和。
注意的地方:
记录树上的点在线段树的位置的同时,也要记录线段树的某个位置是树上哪个点。因为线段树建树的时候是按照线性从左往右的,需要知道每个位置的点是哪个。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<cmath>
#include<map>
#include<string>
#include<set>
#include<queue>
#include <algorithm>
using namespace std;
const int maxn=30010;
const int INF=0x80000000;
int head[maxn],to[maxn*2],nxt[maxn*2],edge;
int tot;
int siz[maxn],dep[maxn],son[maxn],fa[maxn],top[maxn],id[maxn];
int rid[maxn],arr[maxn];
int n;
void init()
{
edge=0;
memset(head,-1,sizeof(head));
tot=0;
memset(son,-1,sizeof(son));
}
void addEdge(int u,int v)
{
to[edge]=v,nxt[edge]=head[u],head[u]=edge++;
to[edge]=u,nxt[edge]=head[v],head[v]=edge++;
}
void dfs1(int now,int f,int d)
{
siz[now]=1;
dep[now]=d;
fa[now]=f;
for(int i=head[now]; i!=-1; i=nxt[i])
{
int v=to[i];
if(v!=f)
{
dfs1(v,now,d+1);
if(son[now]==-1||siz[son[now]]<siz[v])
son[now]=v;
siz[now]+=siz[v];
}
}
}
void dfs2(int now,int tp)
{
id[now]=++tot;
rid[id[now]]=now;
top[now]=tp;
if(son[now]==-1)
return ;
dfs2(son[now],tp);
for(int i=head[now]; i!=-1; i=nxt[i])
{
int v=to[i];
if(v!=fa[now]&&v!=son[now])
dfs2(v,v);
}
}
int sum[4*maxn],maxi[4*maxn];
void push_up(int o)
{
sum[o]=sum[o<<1]+sum[o<<1|1];
maxi[o]=max(maxi[o<<1],maxi[o<<1|1]);
}
void build(int o,int l,int r)
{
if(l==r)
{
maxi[o]=sum[o]=arr[rid[l]];
}
else
{
int m=l+(r-l)/2;
build(o<<1,l,m);
build(o<<1|1,m+1,r);
push_up(o);
}
}
int queryMax(int o,int L,int R,int ql,int qr)
{
if(ql<=L&&R<=qr)
return maxi[o];
else
{
int M=L+(R-L)/2;
int res=INF;
if(ql<=M)
res=max(res,queryMax(o<<1,L,M,ql,qr));
if(M<qr)
res=max(res,queryMax(o<<1|1,M+1,R,ql,qr));
return res;
}
}
int querySum(int o,int L,int R,int ql,int qr)
{
if(ql<=L&&R<=qr)
return sum[o];
else
{
int M=L+(R-L)/2;
int res=0;
if(ql<=M)
res+=querySum(o<<1,L,M,ql,qr);
if(M<qr)
res+=querySum(o<<1|1,M+1,R,ql,qr);
return res;
}
}
void update(int o,int L,int R,int pos,int val)
{
if(L==R)
maxi[o]=sum[o]=val;
else
{
int M=L+(R-L)/2;
if(pos<=M)
update(o<<1,L,M,pos,val);
else
update(o<<1|1,M+1,R,pos,val);
push_up(o);
}
}
int getMax(int x,int y)
{
int res=INF;
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])
swap(x,y);
res=max(res,queryMax(1,1,n,id[top[y]],id[y]));
y=fa[top[y]];
}
if(dep[x]>dep[y])
swap(x,y);
res=max(res,queryMax(1,1,n,id[x],id[y]));
return res;
}
int getSum(int x,int y)
{
int res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])
swap(x,y);
res+=querySum(1,1,n,id[top[y]],id[y]);
y=fa[top[y]];
}
if(dep[x]>dep[y])
swap(x,y);
res+=querySum(1,1,n,id[x],id[y]);
return res;
}
int main()
{
init();
scanf("%d",&n);
for(int i=1; i<n; ++i)
{
int u,v;
scanf("%d%d",&u,&v);
addEdge(u,v);
}
dfs1(1,-1,1);
dfs2(1,1);
for(int i=1;i<=n;++i)
scanf("%d",&arr[i]);
build(1,1,n);
int q;
scanf("%d",&q);
while(q--)
{
char str[10];
int a,b;
scanf("%s%d%d",str,&a,&b);
if(str[0]=='C')
update(1,1,n,id[a],b);
else if(str[1]=='M')
printf("%d\n",getMax(a,b));
else if(str[1]=='S')
printf("%d\n",getSum(a,b));
}
return 0;
}