看到一些求树上两结点之间的路径问题,马上就想到了树链剖分,再结合线段树或者树状数组搞一搞
#include<cstring>
#include<string>
#include<iostream>
#include<queue>
#include<cstdio>
#include<algorithm>
#include<map>
#include<cstdlib>
#include<cmath>
#include<vector>
//#pragma comment(linker, "/STACK:1024000000,1024000000");
using namespace std;
#define INF 0x3f3f3f3f
#define maxn 60006
int val[maxn],val1[maxn];
int fir[maxn],nex[maxn],v[maxn],e_max;
int son[maxn],fa[maxn],siz[maxn],pos[maxn],deep[maxn],top[maxn],tot;
long long mx[4*maxn],sum[4*maxn];
void init_()
{
memset(fir,-1,sizeof fir);
memset(son,-1,sizeof son);
memset(siz,0,sizeof siz);
e_max=0;
tot=1;
}
void add_edge(int s,int t)
{
int e=e_max++;
v[e]=t;
nex[e]=fir[s];
fir[s]=e;
}
void init(int l,int r,int k)
{
if(l==r)
{
sum[k]=mx[k]=val1[l];
return ;
}
int mid=l+r>>1;
init(l,mid,k<<1);
init(mid+1,r,k<<1|1);
sum[k]=sum[k<<1]+sum[k<<1|1];
mx[k]=max(mx[k<<1],mx[k<<1|1]);
}
void update(long long d,int s,int t,int l,int r,int k)
{
if(l==s&&r==t)
{
mx[k]=sum[k]=d;
return ;
}
int mid=l+r>>1;
if(t<=mid) update(d,s,t,l,mid,k<<1);
else if(s>mid) update(d,s,t,mid+1,r,k<<1|1);
else
{
update(d,s,mid,l,mid,k<<1);
update(d,mid+1,t,mid+1,r,k<<1|1);
}
sum[k]=sum[k<<1]+sum[k<<1|1];
mx[k]=max(mx[k<<1],mx[k<<1|1]);
}
long long query_sum(int s,int t,int l,int r,int k)
{
if(s==l&&r==t)
{
return sum[k];
}
int mid=l+r>>1;
if(t<=mid) return query_sum(s,t,l,mid,k<<1);
else if(s>mid)return query_sum(s,t,mid+1,r,k<<1|1);
else return query_sum(s,mid,l,mid,k<<1)+query_sum(mid+1,t,mid+1,r,k<<1|1);
}
long long query_mx(int s,int t,int l,int r,int k)
{
if(s==l&&r==t)
{
return mx[k];
}
int mid=l+r>>1;
if(t<=mid) return query_mx(s,t,l,mid,k<<1);
else if(s>mid)return query_mx(s,t,mid+1,r,k<<1|1);
else return max(query_mx(s,mid,l,mid,k<<1),query_mx(mid+1,t,mid+1,r,k<<1|1));
}
void dfs1(int k,int pre,int d)
{
deep[k]=d;
siz[k]++;
fa[k]=pre;
for(int i=fir[k]; ~i; i=nex[i])
{
int e=v[i];
if(e!=pre)
{
dfs1(e,k,d+1);
siz[k]+=siz[e];
if(son[k]==-1||siz[son[k]]<siz[e]) son[k]=e;
}
}
}
void dfs2(int k,int sp)
{
pos[k]=tot++;
top[k]=sp;
val1[pos[k]]=val[k];
if(son[k]==-1) return ;
dfs2(son[k],sp);
for(int i=fir[k]; ~i; i=nex[i])
{
int e=v[i];
if(e!=fa[k]&&e!=son[k])
{
dfs2(e,e);
}
}
}
void Query1(int s,int t)
{
long long ans=0;
int f1=top[s],f2=top[t];
while(f1!=f2)
{
if(deep[f1]<deep[f2]) swap(f1,f2),swap(s,t);
ans+=query_sum(pos[f1],pos[s],1,tot-1,1);
s=fa[f1];
f1=top[s];
}
if(deep[s]>deep[t]) swap(s,t);
ans+=query_sum(pos[s],pos[t],1,tot-1,1);
printf("%lld\n",ans);
}
void Query2(int s,int t)
{
long long ans=-INF;
int f1=top[s],f2=top[t];
while(f1!=f2)
{
if(deep[f1]<deep[f2]) swap(f1,f2),swap(s,t);
ans=max(ans,query_mx(pos[f1],pos[s],1,tot-1,1));
s=fa[f1];
f1=top[s];
}
if(deep[s]>deep[t]) swap(s,t);
ans=max(ans,query_mx(pos[s],pos[t],1,tot-1,1));
printf("%lld\n",ans);
}
int main()
{
int n;
while(scanf("%d",&n)!=EOF)
{
init_();
for(int i=1; i<n; i++)
{
int a,b;
scanf("%d%d",&a,&b);
add_edge(a,b);
add_edge(b,a);
}
for(int i=1; i<=n; i++) scanf("%d",&val[i]);
dfs1(1,-1,1);
dfs2(1,1);
init(1,tot-1,1);
int q;
scanf("%d",&q);
while(q--)
{
char s[10];
scanf("%s",s);
if(!strcmp(s,"QMAX"))
{
int l,r;
scanf("%d%d",&l,&r);
Query2(l,r);
}
else if(!strcmp(s,"QSUM"))
{
int l,r;
scanf("%d%d",&l,&r);
Query1(l,r);
}
else
{
int i,ti;
scanf("%d%d",&i,&ti);
update(ti,pos[i],pos[i],1,tot-1,1);
}
}
}
return 0;
}