比较基础的树链剖分吧。。。
树链剖分的关键就在把在同一条链上的点的编号连续,这样就可以快速地用线段树维护修改与求值了(当然据说线段数组也可以)。
而我们找链的依据是什么?
这就要提到重链了。
首先我们提出一个概念:重儿子
重儿子就是儿子中子节点数最多的一个。重儿子的衔接就形成了重链。
如图:
其中,加粗黑线标出的是重链。
而我们只要不断地跳到当前链的顶端来查找与修改区间(每个节点只会位于一条链上)。
我们通过两次dfs来处理重链与重儿子,第一次找出重儿子,第二次找出重链并对节点进行新的编号,并且要处理出一个top[]数组,表示当前链的顶端,方便往上跳。
每次往上跳时先把深度较大也就是离lca较远的一个节点先往上跳,并且每次处理当前链上的点,重复操作,直到两个节点跳到同一条重链上,最后在两个节点间操作一次,修改或统计答案。
注:代码里线段树用的是动态开点
Code:
#include<cstdio>
#include<cstdlib>
#include<iostream>
using namespace std;
struct node{int y,next;}a[60010];
int first[30010],dep[30010],top[30010],tot[30010],son[30010],image[30010],fact[30010],fa[30010],num[30010];
int lc[60010],rc[60010],mmax[60010],sum[60010];
int len(0),root,d,v,n,m;
bool tf=false;
/*
a[],first[] 邻接表
树链剖分数组:
dep[] 当前节点的深度
top[] 当前节点所在链的顶端
tot[] 当前节点的子节点数量(包括自己)
son[] 当前节点的重儿子
fa[] 当前节点的父亲
image[] 对于每一个原本的编号,它剖分后的新编号(线段树上的编号)
fact[] 对于每一个线段树上的编号,它原本的编号
num[] 当前节点的值
线段树数组:
lc[] 当前节点的左儿子
rc[] 当前节点的右儿子
mmax[] 最大值
sum[] 和
*/
void ins(int x,int y){a[++len]=(node){y,first[x]};first[x]=len;}
void dfs_1(int x)
{
tot[x]=1;
for(int i=first[x];i;i=a[i].next)
{
int y=a[i].y;
if(y!=fa[x])
{
dep[y]=dep[x]+1;
fa[y]=x;
dfs_1(y);
if(tot[y]>tot[son[x]]) son[x]=y;
tot[x]+=tot[y];
}
}
}
void dfs_2(int x,int tp)
{
len++;
top[x]=tp;image[x]=len;fact[len]=x;
if(son[x]) dfs_2(son[x],tp);
for(int i=first[x];i;i=a[i].next)
{
int y=a[i].y;
if(y!=son[x] && y!=fa[x]) dfs_2(y,y);
}
}
void update(int &now,int l,int r)
{
if(now==0) now=++len;
sum[now]+=d;
mmax[now]=-1e9;
if(l==r)
{
if(tf) mmax[now]=d;
return;
}
int mid=(l+r)/2;
if(v<=mid) update(lc[now],l,mid);
else update(rc[now],mid+1,r);
mmax[now]=max(mmax[lc[now]],mmax[rc[now]]);
}
void change(int x,int y)
{
d=-num[x];v=image[x];tf=false;
update(root,1,n);
d=num[x]=y;tf=true;
update(root,1,n);
}
int findmax(int now,int x,int y,int l,int r)
{
if(x==l && y==r) return mmax[now];
int mid=(l+r)/2;
if(y<=mid) return findmax(lc[now],x,y,l,mid);
else if(mid<x) return findmax(rc[now],x,y,mid+1,r);
else return max(findmax(lc[now],x,mid,l,mid),findmax(rc[now],mid+1,y,mid+1,r));
}
void get_max(int x,int y)
{
int tx=top[x],ty=top[y],ans=-1e9;
while(tx!=ty)
{
if(dep[ty]<dep[tx]) //人为限定y是深度较大的一个
{
swap(tx,ty);
swap(x,y);
}
ans=max(ans,findmax(root,image[ty],image[y],1,n));
y=fa[ty];ty=top[y];
}
if(dep[x]>dep[y]) swap(x,y);
ans=max(ans,findmax(root,image[x],image[y],1,n));
printf("%d\n",ans);
}
int findsum(int now,int x,int y,int l,int r)
{
if(x==l && y==r) return sum[now];
int mid=(l+r)/2;
if(y<=mid) return findsum(lc[now],x,y,l,mid);
else if(mid<x) return findsum(rc[now],x,y,mid+1,r);
else return findsum(lc[now],x,mid,l,mid)+findsum(rc[now],mid+1,y,mid+1,r);
}
void get_sum(int x,int y)
{
int tx=top[x],ty=top[y],ans(0);
while(tx!=ty)
{
if(dep[ty]<dep[tx])
{
swap(tx,ty);
swap(x,y);
}
ans+=findsum(root,image[ty],image[y],1,n);
y=fa[ty];ty=top[y];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=findsum(root,image[x],image[y],1,n);
printf("%d\n",ans);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=2*n;i++) mmax[i]=-1e9;
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d %d",&x,&y);
ins(x,y);
ins(y,x);
}
dep[1]=1;fa[1]=0;
dfs_1(1);
len=0;dfs_2(1,1);
len=0;
for(int i=1;i<=n;i++)
{
int x;
scanf("%d",&x);
num[i]=d=x;
v=image[i];
tf=true;
update(root,1,n);
}
scanf("%d",&m);
for(int i=1;i<=m;i++)
{
char s[10];
int x,y;
scanf("%s %d %d",s,&x,&y);
if(s[1]=='H') change(x,y);
if(s[1]=='M') get_max(x,y);
if(s[1]=='S') get_sum(x,y);
}
}