树链剖分的基础题
因为复习到了这个部分突然发现竟然没有题解所以现在补一个。。
一些基础的东西。。。
重儿子:siz[u]为v的子节点中siz值最大的,那么u就是v的重儿子。
然后简单的用两次dfs计算出每个节点的father,deep,size, son,w,top
其他简单的就不说了
w表示的是当前节点与其付清节点的连边在线段树中的位置
top表示的是当前节点所在的链的顶端的节点
其他的东西都比较简单,代码可以着重看下查询那段,因为不用求LCA所以不错。。。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define inf 0x7fffffff/4
#define MAX 123333
#define rep(i,j,k) for(int i=j;i<=k;i++)
using namespace std;
int n,m,father[MAX],deep[MAX],size[MAX],w[MAX],value[MAX],head[2*MAX];
int next[2*MAX],to[2*MAX],top[MAX],Max[MAX],sum[MAX];
int b[MAX],son[MAX],num=0,tot=0,done[MAX];
int l[4*MAX],r[4*MAX];
void add(int from,int To)
{
to[++tot]=To;
next[tot]=head[from];
head[from]=tot;
}
void dfs(int fa,int x,int dep)
{
deep[x]=dep;
size[x]=1;
son[x]=0;
for(int i=head[x];i;i=next[i])
if(to[i]!=fa)
{
father[to[i]]=x;
dfs(x,to[i],dep+1);
if(size[to[i]]>size[son[x]])
son[x]=to[i];
size[x]+=size[to[i]];
}
}
void dfs2(int tp,int x)
{
w[x]=++num;
b[num]=value[x];
top[x]=tp;
if(son[x])
dfs2(top[x],son[x]);
for(int i=head[x];i;i=next[i])
if(to[i]!=son[x]&&to[i]!=father[x])
dfs2(to[i],to[i]);
}
inline void up(int x)
{
if(l[x]==r[x])
return;
Max[x]=max(Max[2*x],Max[2*x+1]);
sum[x]=sum[2*x]+sum[2*x+1];
}
inline void build(int x,int la,int ra)
{
l[x]=la;
r[x]=ra;
if(la==ra)
{
Max[x]=sum[x]=b[la];
return;
}
int mid=(la+ra)>>1;
build(2*x,la,mid);
build(2*x+1,mid+1,ra);
up(x);
}
void change(int x,int pos,int number)
{
if(l[x]==r[x]&&l[x]==pos)
{
Max[x]=sum[x]=number;
return;
}
int mid=(l[x]+r[x])>>1;
if(pos<=mid)
change(2*x,pos,number);
else
change(2*x+1,pos,number);
up(x);
}
int query_max(int x,int la,int ra)
{
if(l[x]>ra||r[x]<la)
return -inf;
if(la<=l[x]&&r[x]<=ra)
return Max[x];
int mid=(l[x]+r[x])>>1;
return max(query_max(2*x,la,ra),query_max(2*x+1,la,ra));
}
inline int ask_max(int x,int y)
{
int ret=-inf;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])
swap(x,y);
ret=max(ret,query_max(1,w[top[x]],w[x]));
x=father[top[x]];
}
if(w[x]<w[y])
swap(x,y);
ret=max(ret,query_max(1,w[y],w[x]));
return ret;
}
int query(int x,int la,int ra)
{
if(l[x]>ra||r[x]<la)
return 0;
if(la<=l[x]&&r[x]<=ra)
return sum[x];
return query(2*x,la,ra)+query(2*x+1,la,ra);
}
inline int ask(int x,int y)
{
int ret=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])
swap(x,y);
ret+=query(1,w[top[x]],w[x]);
x=father[top[x]];
}
if(w[x]<w[y])
swap(x,y);
return ret+query(1,w[y],w[x]);
}
int main()
{
scanf("%d",&n);
rep(i,1,n-1)
{
int a1,a2;
scanf("%d%d",&a1,&a2);
add(a1,a2),add(a2,a1);
}
rep(i,1,n)
scanf("%d",&value[i]);
dfs(0,1,1);
memset(done,0,sizeof(done));
dfs2(1,1);
build(1,1,num);
for(scanf("%d",&m);m;m--)
{
getchar();
char ask_[90];
int a1,a2;
scanf("%s%d%d",ask_,&a1,&a2);
if(ask_[0]=='C')
change(1,w[a1],a2);
else
if(ask_[1]=='M')
printf("%d\n",ask_max(a1,a2));
else
printf("%d\n",ask(a1,a2));
}
return 0;
}