首先表示,以后算法,光看思路,尽量自己实现,大胆实现即可
好,对树链剖分有一点思路了
我的理解就是,把树拆成一条条链,把链连起来,用数据结构进行维护,几一般用线段树(多)或splay(相对少)维护
表示真练连代码能力,要注意细节
步骤:
dfs1——找出各个节点的深度,父亲,重儿子
dfs2——通过已经知道的重儿子,进行轻重边剖分,并对各节点重新编号,以便数据结构时维护(一条链中元素在序列中是相邻的)
build——初始化数据结构,并初始数据。
//以上预处理,成功把树拆分成链
开始询问更新
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
using namespace std;
const int maxn=300009;
struct aa{int to,pre;}edge[maxn];
struct aaa{int max,sum,l,r;}tree[maxn*4];
int n,p,head[maxn],w[maxn],tot;// 初始化
int fa[maxn],dep[maxn],size[maxn],son[maxn],rt;//dfs1用
int label,tid[maxn],top[maxn];//dfs2用
int root,kk;//线段树与树链剖分融合
void addedge(int from,int to)
{
edge[++tot].to=to;edge[tot].pre=head[from];head[from]=tot;
}
void dfs1(int u,int fat,int depth)
{
fa[u]=fat;dep[u]=depth;
int maxsize=0;size[u]=1;
for (int i=head[u];i;i=edge[i].pre) if (edge[i].to!=fat)
{
int v=edge[i].to;dfs1(v,u,depth+1);size[u]+=size[v];
if (size[v]>maxsize) maxsize=size[v],son[u]=v;
}
}
void dfs2(int u,int anc)
{
tid[u]=++label;top[u]=anc;
if (son[u]==0) return ;
dfs2(son[u],anc);
for (int i=head[u];i;i=edge[i].pre)
if (edge[i].to!=son[u]&&edge[i].to!=fa[u])
dfs2(edge[i].to,edge[i].to);
}
void build(int k,int l,int r)
{
tree[k].l=l;tree[k].r=r;
if (l==r) return ;
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
}
void updata(int k,int x,int t)
{
if (tree[k].l==tree[k].r) {tree[k].max=t;tree[k].sum=t;return ;}
int mid=(tree[k].l+tree[k].r)>>1;
if (mid>=x) updata(k<<1,x,t);else updata(k<<1|1,x,t);
tree[k].max=max(tree[k<<1].max,tree[k<<1|1].max);
tree[k].sum=tree[k<<1].sum+tree[k<<1|1].sum;
}
int mx(int k,int l,int r)
{
if (tree[k].l==l&&tree[k].r==r) return tree[k].max;
int mid=(tree[k].l+tree[k].r)>>1;
if (mid>=r) return mx(k<<1,l,r);
else if (mid<l) return mx(k<<1|1,l,r);
else return max(mx(k<<1,l,mid),mx(k<<1|1,mid+1,r));
}
int su(int k,int l,int r)
{
if (tree[k].l==l&&tree[k].r==r) return tree[k].sum;
int mid=(tree[k].l+tree[k].r)>>1;
if (mid>=r) return su(k<<1,l,r);
else if (mid<l) return su(k<<1|1,l,r);
else return su(k<<1,l,mid)+su(k<<1|1,mid+1,r);
}
int qmax(int u,int v)
{
int ans=-0x3f3f3f3f;
while (top[u]!=top[v])
{
if (dep[top[u]]<dep[top[v]])swap(u,v);
ans=max(ans,mx(1,tid[top[u]],tid[u]));
u=fa[top[u]];
}
if (dep[u]>dep[v])swap(u,v);
ans=max(ans,mx(1,tid[u],tid[v]));
return ans;
}
int qsum(int u,int v)
{
int ans=0;
while (top[u]!=top[v])
{
if (dep[top[u]]<dep[top[v]])swap(u,v);
ans=ans+su(1,tid[top[u]],tid[u]);
u=fa[top[u]];
}
if (dep[u]>dep[v])swap(u,v);
ans=ans+su(1,tid[u],tid[v]);
return ans;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<n;i++)
{
int a,b;scanf("%d%d",&a,&b);
addedge(a,b);addedge(b,a);
}
for (int i=1;i<=n;i++) scanf("%d",&w[i]);
rt=1;
dfs1(rt,0,1);
dfs2(rt,rt);
build(1,1,label);
for (int i=1;i<=label;i++) updata(1,tid[i],w[i]);
scanf("%d",&p);
for (int i=1;i<=p;i++)
{
char ch[100];int u,v,t;
scanf("%s",&ch);
if (ch[1]=='H') scanf("%d%d",&u,&t),
updata(1,tid[u],t);
if (ch[1]=='M') scanf("%d%d",&u,&v),
printf("%d\n",qmax(u,v));
if (ch[1]=='S') scanf("%d%d",&u,&v),
printf("%d\n",qsum(u,v));
}
return 0;
}