题目描述 传送门
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点最大权值 III. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身。
学习了树链剖分,另外知道了线段树数组一般要开到4倍。
代码
#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>
#include<climits>
#include<cmath>
#define INF INT_MAX/2
using namespace std;
typedef long long LL;
const int maxn=30010;
vector<int> g[maxn];
int sz[maxn],top[maxn],fa[maxn],depth[maxn],son[maxn],tot=0,w[maxn],n;
int sumv[100005],maxv[100005];
void dfs1(int u,int f){
sz[u]=1;
son[u]=0;
fa[u]=f;
depth[u]=depth[f]+1;
for(int i=0;i<g[u].size();i++) if(g[u][i]!=f){
dfs1(g[u][i],u);
if(sz[g[u][i]]>sz[son[u]]){
son[u]=g[u][i];
}
sz[u]+=sz[g[u][i]];
}
}
void dfs2(int u,int tp){
w[u]=++tot;
top[u]=tp;
if(!son[u]) return;
dfs2(son[u],tp);
for(int i=0;i<g[u].size();i++) if(g[u][i]!=son[u]&&g[u][i]!=fa[u])
dfs2(g[u][i],g[u][i]);
}
void change(int o,int L,int R,int x,int k){
if(L==R){
sumv[o]=k;
maxv[o]=k;
return ;
}
int M=L+R>>1;
if(x<=M) change(o<<1,L,M,x,k);
else change(o<<1|1,M+1,R,x,k);
sumv[o]=sumv[o<<1]+sumv[o<<1|1];
maxv[o]=max(maxv[o<<1],maxv[o<<1|1]);
}
int querymax(int o,int L,int R,int l,int r){
if(l<=L&&r>=R) return maxv[o];
int M=L+R>>1;
int ans=-INF;
if(l<=M) ans=max(ans,querymax(o<<1,L,M,l,r));
if(r>M) ans=max(ans,querymax(o<<1|1,M+1,R,l,r));
return ans;
}
int querysum(int o,int L,int R,int l,int r){
if(l<=L&&r>=R) return sumv[o];
int M=L+R>>1;
int ans=0;
if(l<=M) ans=ans+querysum(o<<1,L,M,l,r);
if(r>M) ans=ans+querysum(o<<1|1,M+1,R,l,r);
return ans;
}
int getmax(int u,int v){
int ans=-INF;
int f1=top[u],f2=top[v];
while(f1!=f2){
if(depth[f1]<depth[f2]){
swap(f1,f2);
swap(u,v);
}
ans=max(ans,querymax(1,1,tot,w[f1],w[u]));
u=fa[f1];
f1=top[u];
}
if(depth[u]<depth[v]) swap(u,v);
return max(ans,querymax(1,1,tot,w[v],w[u]));
}
int getsum(int u,int v){
int ans=0;
int f1=top[u],f2=top[v];
while(f1!=f2){
if(depth[f1]<depth[f2]){
swap(f1,f2);
swap(u,v);
}
ans+=querysum(1,1,tot,w[f1],w[u]);
u=fa[f1];
f1=top[u];
}
if(depth[u]<depth[v]) swap(u,v);
return ans+querysum(1,1,tot,w[v],w[u]);
}
int main(){
cin>>n;
memset(sumv,0,sizeof(sumv));
memset(maxv,0,sizeof(maxv));
for(int i=0;i<n-1;i++){
int a,b;
scanf("%d%d",&a,&b);
g[a].push_back(b);
g[b].push_back(a);
}
dfs1(1,0);
dfs2(1,1);
for(int i=1;i<=n;i++){
int a;
scanf("%d",&a);
change(1,1,tot,w[i],a);
}
int q,l,r;
char c[20];
cin>>q;
while(q--){
scanf("%s%d%d",c,&l,&r);
if(c[0]=='C') change(1,1,tot,w[l],r);
else if(c[1]=='M') printf("%d\n",getmax(l,r));
else printf("%d\n",getsum(l,r));
}
return 0;
}