题意:
给你一棵树,n个节点,每个节点都有一个val,现在进行q次操作,操作有三种,第一种查询u到v的最大值,第二种,查询u到v的val和,第三种,将某个点修改val
思路:
树链剖分基本题,相比起前面几道树链剖分,这个题维护的是点,虽然我们从理论上来说,树链剖分是根据边来剖分的,但是其实我们维护的时候维护点更简单,所以这个题虽然写起来麻烦一点,但是逻辑上反而更顺
用两颗线段树分别维护最大值和最小值,和维护边不同的是,在向上提的过程的边界处理,比如左右相等时也要继续比较
错误及反思:
代码:
#include<bits/stdc++.h>
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
const int N =30010;
pair<int,int> segtree[N*4];
int top[N],son[N],first[N],fa[N],id[N],val[N],depth[N],si[N];
pair<int,int> edge[N],b[N];
int n,q,tot=0,tid=0;
void addedge(int x,int y)
{
edge[tot].first=x;
edge[tot].second=first[y];
first[y]=tot++;
edge[tot].first=y;
edge[tot].second=first[x];
first[x]=tot++;
}
void dfs1(int now,int bef,int dep)
{
depth[now]=dep;
fa[now]=bef;
si[now]=1;
for(int i=first[now];i!=-1;i=edge[i].second)
{
if(edge[i].first!=bef)
{
dfs1(edge[i].first,now,dep+1);
si[now]+=si[edge[i].first];
if(son[now]==-1) son[now]=edge[i].first;
else son[now]=si[edge[i].first]>si[son[now]]?edge[i].first:son[now];
}
}
}
void dfs2(int now,int tp)
{
id[now]=tid++;
top[now]=tp;
if(son[now]!=-1) dfs2(son[now],tp);
for(int i=first[now];i!=-1;i=edge[i].second)
if(edge[i].first!=fa[now]&&edge[i].first!=son[now])
dfs2(edge[i].first,edge[i].first);
}
void change1(int pos,int v,int l,int r,int rt)//largest
{
if(pos==l&&l==r)
{
segtree[rt].first=v;
return ;
}
int m=(l+r)/2;
if(m>=pos) change1(pos,v,lson);
if(m<pos) change1(pos,v,rson);
segtree[rt].first=max(segtree[rt<<1].first,segtree[rt<<1|1].first);
}
void change2(int pos,int v,int l,int r,int rt)//sum
{
if(pos==l&&l==r)
{
segtree[rt].second=v;
return ;
}
int m=(l+r)/2;
if(m>=pos) change2(pos,v,lson);
if(m<pos) change2(pos,v,rson);
segtree[rt].second=segtree[rt<<1].second+segtree[rt<<1|1].second;
}
int calmax(int L,int R,int l,int r,int rt)
{
if(L<=l&&R>=r)
return segtree[rt].first;
int m=(l+r)/2;
int ans=-1e9;
if(m>=L) ans=max(ans,calmax(L,R,lson));
if(m<R) ans=max(ans,calmax(L,R,rson));
return ans;
}
int calsum(int L,int R,int l,int r,int rt)
{
if(L<=l&&R>=r)
return segtree[rt].second;
int m=(l+r)/2;
int ans=0;
if(m>=L) ans+=calsum(L,R,lson);
if(m<R) ans+=calsum(L,R,rson);
return ans;
}
int query1(int L,int R)//max
{
int f1=top[L],f2=top[R];
int ans=-30001;
while(f1!=f2)
{
if(depth[f1]<depth[f2])
{
swap(L,R);
swap(f1,f2);
}
ans=max(ans,calmax(id[f1],id[L],0,tid-1,1));
L=fa[f1];
f1=top[L];
}
if(depth[L]>depth[R]) swap(L,R);
ans=max(ans,calmax(id[L],id[R],0,tid-1,1));
return ans;
}
int query2(int L,int R)//sum
{
int f1=top[L],f2=top[R];
int ans=0;
while(f1!=f2)
{
if(depth[f1]<depth[f2])
{
swap(L,R);
swap(f1,f2);
}
ans+=calsum(id[f1],id[L],0,tid-1,1);
L=fa[f1];
f1=top[L];
}
if(depth[L]>depth[R]) swap(L,R);
ans+=calsum(id[L],id[R],0,tid-1,1);
return ans;
}
int main()
{
memset(first,-1,sizeof(first));
memset(son,-1,sizeof(son));
scanf("%d",&n);
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&b[i].first,&b[i].second);
addedge(b[i].first,b[i].second);
}
for(int i=1;i<=n;i++)
scanf("%d",&val[i]);
dfs1(1,1,1);
dfs2(1,1);
for(int i=1;i<=n;i++)
{
change1(id[i],val[i],0,tid-1,1);
change2(id[i],val[i],0,tid-1,1);
}
scanf("%d",&q);
while(q--)
{
char temp[20];
int ta,tb;
scanf("%s%d%d",temp,&ta,&tb);
if(temp[1]=='M')
printf("%d\n",query1(ta,tb));
else if(temp[1]=='S')
printf("%d\n",query2(ta,tb));
else
{
change1(id[ta],tb,0,tid-1,1);
change2(id[ta],tb,0,tid-1,1);
}
}
}