又是一道裸题,刚学会树链剖分,存个模版。
#include<iostream>
#include<cstdio>
using namespace std;
struct bian{
int to;
}b[60005];
int fst[30005],nxt[60005];
int tot=1;
void build(int f,int t)
{
b[++tot].to=t;
nxt[tot]=fst[f];
fst[f]=tot;
}
int sd[30005],fa[30005];
int sz[30005],zson[30005];
void dfs(int u)
{
int v,ans=0;
for(int i=fst[u];i;i=nxt[i])
{
v=b[i].to;
if(!sd[v])
{
sd[v]=sd[u]+1;
fa[v]=u;
dfs(v);
sz[u]+=sz[v];
if(sz[v]>ans)
{
ans=sz[v];
zson[u]=v;
}
}
}
sz[u]++;
}
int dfs_clock;
int ts[30005],top[30005];
int intr[30005];
void dfs_2(int u,int x)
{
ts[u]=++dfs_clock;
intr[dfs_clock]=u;
top[u]=x;
if(!zson[u])
return ;
dfs_2(zson[u],x);
int v;
for(int i=fst[u];i;i=nxt[i])
{
v=b[i].to;
if(v==fa[u]||v==zson[u])
continue;
top[v]=v;
dfs_2(v,v);
}
}
int w[30005];
struct xds{
int l,r,sum,maxx;
}tree[150005];
void up(int dq)
{
tree[dq].sum=tree[dq<<1].sum+tree[dq<<1|1].sum;
tree[dq].maxx=max(tree[dq<<1].maxx,tree[dq<<1|1].maxx);
}
void build(int dq,int l,int r)
{
tree[dq].l=l;
tree[dq].r=r;
if(l==r)
{
tree[dq].sum=tree[dq].maxx=w[intr[l]];
return ;
}
int mid=(l+r)>>1;
build(dq<<1,l,mid);
build(dq<<1|1,mid+1,r);
up(dq);
}
void change(int dq,int l,int r,int d)
{
if(tree[dq].l==l&&tree[dq].r==r)
{
tree[dq].sum=d;
tree[dq].maxx=d;
return ;
}
int mid=(tree[dq].l+tree[dq].r)>>1;
if(mid>=l) change(dq<<1,l,r,d);
if(mid<r) change(dq<<1|1,l,r,d);
up(dq);
}
int inf=1e9;
int ask_max(int dq,int l,int r)
{
if(tree[dq].l>=l&&tree[dq].r<=r)
return tree[dq].maxx;
int mid=(tree[dq].l+tree[dq].r)>>1;
int ans=-inf;
if(mid>=l) ans=ask_max(dq<<1,l,r);
if(mid<r) ans=max(ans,ask_max(dq<<1|1,l,r));
return ans;
}
int ask_sum(int dq,int l,int r)
{
if(tree[dq].l>=l&&tree[dq].r<=r)
return tree[dq].sum;
int mid=(tree[dq].l+tree[dq].r)>>1;
int ans=0;
if(mid>=l) ans+=ask_sum(dq<<1,l,r);
if(mid<r) ans+=ask_sum(dq<<1|1,l,r);
return ans;
}
int find_max(int x,int y)
{
int ans=-inf;
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(sd[fx]<sd[fy])
{
swap(x,y);
swap(fx,fy);
}
ans=max(ans,ask_max(1,ts[fx],ts[x]));
x=fa[fx];fx=top[x];
}
if(sd[x]>sd[y]) swap(x,y);
ans=max(ans,ask_max(1,ts[x],ts[y]));
return ans;
}
int find_sum(int x,int y)
{
int ans=0;
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(sd[fx]<sd[fy])
{
swap(x,y);
swap(fx,fy);
}
ans+=ask_sum(1,ts[fx],ts[x]);
x=fa[fx];fx=top[x];
}
if(sd[x]>sd[y]) swap(x,y);
ans+=ask_sum(1,ts[x],ts[y]);
return ans;
}
char s[10];
int main()
{
int n;
scanf("%d",&n);
int a,b;
for(int i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
build(a,b);
build(b,a);
}
sd[1]=1;fa[1]=1;
dfs(1);top[1]=1;
dfs_2(1,1);
for(int i=1;i<=n;i++)
scanf("%d",&w[i]);
build(1,1,n);
int q;
scanf("%d",&q);
while(q--)
{
scanf("%s%d%d",s,&a,&b);
if(s[1]=='M')
printf("%d\n",find_max(a,b));
if(s[1]=='S')
printf("%d\n",find_sum(a,b));
if(s[1]=='H')
change(1,ts[a],ts[a],b);
}
}