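// 树链剖分裸题。 A textbook heavy-light decomposition exercise: point update, path maximum, path sum.
// 题目链接 (problem link): the URL is missing from the original — presumably BZOJ 1036 [ZJOI2008] 树的统计 Count; confirm.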
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
// Capacity bound: vertices are numbered 1..n and n must stay below N.
constexpr int N = 30500;
using namespace std;

int n;           // number of vertices
int m;           // number of operations
int tot;         // edges added so far; edge ids start at 1 so 0 ends a list
int lable;       // last HLD position handed out by dfs1
// Forward-star adjacency lists; each undirected edge is stored twice (N*2).
int Next[N*2];   // Next[e]: next edge in the same list (0 = end of list)
int head[N*2];   // head[v]: first edge leaving v (0 = none)
int tree[N*2];   // tree[e]: endpoint of edge e
int a[N];        // a[v]: initial weight of v (updates go to the segment tree only)
// Filled by dfs():
int fa[N];       // parent (0 for the root)
int dep[N];      // depth (root = 1)
int son[N];      // heavy child (0 when there are no children)
// NOTE: under `using namespace std` in C++17, an unqualified `size` is
// ambiguous with std::size; refer to this array as ::size.
int size[N];     // subtree size
// Filled by dfs1():
int tid[N];      // tid[v]: HLD position of vertex v
int number[N];   // number[p]: vertex at HLD position p (inverse of tid)
int top[N];      // top[v]: head of the heavy chain containing v
// Segment tree over HLD positions 1..n (root id 1, children 2i and 2i+1).
int Max[N*8];    // range maximum
int sum[N*8];    // range sum
bool visit[N];   // DFS visited flags, cleared by main before each pass
// Push the directed edge x -> y onto the front of x's forward-star list.
// head[x] holds the newest edge id, Next[] links to older edges and tree[]
// stores the endpoint. Ids are pre-incremented from tot, so id 0 means "none".
void add(int x,int y)
{
    ++tot;
    tree[tot]=y;
    Next[tot]=head[x];
    head[x]=tot;
}
void dfs(int x,int depth,int father)
{
visit[x]=true;fa[x]=father;dep[x]=depth;son[x]=0;size[x]=1;
int maxsize=0;
for (int i=head[x];i;i=Next[i])
if (!visit[tree[i]])
{
dfs(tree[i],depth+1,x);
size[x]+=size[tree[i]];
if (size[tree[i]]>maxsize)
{
maxsize=size[tree[i]];
son[x]=tree[i];
}
}
}
// Second HLD pass: hand out positions and chain heads.
//   x        - current vertex
//   ancestor - head of the heavy chain x belongs to
// Sets top[x], gives x the next position (tid[x]) and records the inverse
// map number[pos] = x. The heavy child is visited first, so each heavy chain
// gets consecutive positions. Every other unvisited neighbour starts its own
// chain. Relies on visit[] having been cleared by the caller.
void dfs1(int x,int ancestor)
{
    visit[x]=true;
    top[x]=ancestor;
    lable++;
    tid[x]=lable;
    number[lable]=x;
    if (son[x]) dfs1(son[x],ancestor);   // heavy child continues this chain
    for (int e=head[x];e!=0;e=Next[e])
    {
        int v=tree[e];
        if (visit[v]) continue;
        dfs1(v,v);                        // light child heads a new chain
    }
}
// Recompute segment-tree node x from its children 2x and 2x+1.
void up(int x)
{
    const int lc=x*2;
    const int rc=lc+1;
    sum[x]=sum[lc]+sum[rc];
    Max[x]=(Max[lc]<Max[rc])?Max[rc]:Max[lc];
}
// Build segment-tree node `id`, which covers HLD positions [l, r].
// Leaf l stores the weight of the vertex at that position (a[number[l]])
// in both Max[] and sum[]. Internal nodes are combined with up().
// The original pre-store `Max[id]=-1<<29; sum[id]=0;` was removed: both
// branches overwrite it, and `-1<<29` shifts a negative value (UB before C++20).
void build(int l,int r,int id)
{
    if (l==r)
    {
        Max[id]=sum[id]=a[number[l]];
        return;
    }
    int mid=(l+r)/2;
    build(l,mid,id*2);
    build(mid+1,r,id*2+1);
    up(id);
}
// Point assignment: set HLD position x to d in the node `id` covering
// [l, r], then refresh ancestors via up(). A call whose range does not
// contain x changes nothing.
void change(int x,int l,int r,int id,int d)
{
    if (x<l||x>r) return;
    if (l==r)   // with x in [l, r], a one-element range is exactly position x
    {
        Max[id]=d;
        sum[id]=d;
        return;
    }
    int mid=(l+r)/2;
    if (x<=mid) change(x,l,mid,id*2,d);
    else change(x,mid+1,r,id*2+1,d);
    up(id);
}
// Segment-tree range query over HLD positions [x, y].
//   x, y - query range (inclusive)
//   id   - current node, covering positions [l, r]
//   q    - 0 for the maximum, anything else for the sum
// Returns the aggregate over [x, y] ∩ [l, r]. A disjoint node contributes the
// identity: -(1<<29) for max, 0 for sum. Read-only: Max[]/sum[] are untouched.
// -(1<<29) replaces the original -1<<29 (left shift of a negative value is UB
// before C++20). The original's trailing up(id) was unreachable and has been removed.
int query(int x,int y,int id,int l,int r,int q)
{
    if (l>y||r<x)
        return (q==0) ? -(1<<29) : 0;
    if (x<=l&&r<=y)
        return (q==0) ? Max[id] : sum[id];
    int mid=(l+r)/2;
    int lv=query(x,y,id*2,l,mid,q);
    int rv=query(x,y,id*2+1,mid+1,r,q);
    return (q==0) ? max(lv,rv) : lv+rv;
}
// Path query between vertices x and y.
//   q == 0 -> maximum vertex weight on the path
//   q != 0 -> sum of vertex weights on the path
// While x and y lie on different heavy chains, lift the endpoint whose chain
// head is deeper. First query that chain segment [tid[top[x]], tid[x]], which
// is contiguous because dfs1 numbers each chain consecutively. Once both share
// a chain, query the remaining segment between them.
// The max identity is -(1<<29); the original -1<<29 shifted a negative value (UB before C++20).
int Query(int x,int y,int q)
{
    int ans=(q==0) ? -(1<<29) : 0;
    while (top[x]!=top[y])
    {
        if (dep[top[x]]<dep[top[y]]) swap(x,y);
        if (q==0) ans=max(ans,query(tid[top[x]],tid[x],1,1,n,0));
        else ans+=query(tid[top[x]],tid[x],1,1,n,1);
        x=fa[top[x]];
    }
    if (dep[x]>dep[y]) swap(x,y);   // x is now the shallower endpoint
    if (q==0) ans=max(ans,query(tid[x],tid[y],1,1,n,0));
    else ans+=query(tid[x],tid[y],1,1,n,1);
    return ans;
}
// Entry point: reads the tree, builds the decomposition and the segment tree,
// then answers the operations.
// Input layout, as parsed below:
//   n; n-1 edges "x y"; n vertex weights; m; then m lines "<op> x y".
// Ops are dispatched on their letters: s[0]=='C' sets vertex x's weight to y,
// s[1]=='M' prints the path maximum x..y, s[1]=='S' prints the path sum
// (presumably spelled CHANGE / QMAX / QSUM — confirm against the problem).
int main()
{
    // n < 1 would make build(1, n, 1) recurse without end, so stop early.
    if (scanf("%d",&n)!=1||n<1) return 0;
    tot=lable=0;
    for (int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);   // undirected edge: store both directions
    }
    for (int i=1;i<=n;i++) scanf("%d",&a[i]);
    scanf("%d",&m);
    for (int i=1;i<=n;i++) visit[i]=false;
    dfs(1,1,0);              // root the tree at vertex 1
    for (int i=1;i<=n;i++) visit[i]=false;
    dfs1(1,1);
    build(1,n,1);
    char s[100];
    for (int i=1;i<=m;i++)
    {
        int x,y;
        // %99s cannot overflow s. On truncated input, stop rather than use
        // uninitialized x/y as array indices.
        if (scanf("%99s%d%d",s,&x,&y)!=3) break;
        if (s[0]=='C') change(tid[x],1,n,1,y);
        else if (s[1]=='M') printf("%d\n",Query(x,y,0));
        else if (s[1]=='S') printf("%d\n",Query(x,y,1));
    }
    return 0;
}