【金华集训 && 笔记】 Day 5 笔记——树链剖分

笔记

树链剖分:顾名思义,就是对树的链不断地剖,不断地分。
链的分类:
1、重链剖分
2、长链剖分

  • 重链剖分:每次从最重的儿子剖下去
    以sz[u]表示u的结点个数
  • 长链剖分:每次找最深的儿子剖下去

为啥重链剖分比长链剖分好???

  • 如果对于一棵树,每个节点的儿子都很深
  • 长链剖分就可能会经过 n \sqrt n n 条链
  • 直接崩掉

证明:重链剖分 O ( l o g n ) O(logn) O(logn)

  • 对于每个点,假设只有一个重儿子
  • s z [ u ] > = s z [ h e a v y u ] + s z [ v ] + 1 sz[u]>=sz[heavy_u]+sz[v]+1 sz[u]>=sz[heavyu]+sz[v]+1
  • 然后不会了。。
  • //抱歉,这里只是笔记,当听到这里的时候后面就没跟上。。将就着吧。

实现方式:
1、每条链建立一棵线段树

2、建一棵线段树

  • 先做一遍dfs,预处理出每个点的字子树个数以及重儿子编号等等
  • 每次dfs,优先访问重儿子
  • 好处???
    仍然是dfs序,重链都连在了一起,且是连续的一段
  • 然后就可以用top[x]记录x所属的重链的头号元素
  • 然后就可以用dfs序处理一系列问题
  • 具体实现请看程序。

听说树链剖分一般用来实现树上路径问题???


例题讲解

题目传送门
题目描述:

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点u的权值改为t

II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值

III. QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

Solution

首先,这道题是求树上路径问题,我们这时就应该想到两个算法:点分治和树链剖分
因为,,这两个算法似乎都是用来求树上路径问题的???也许吧。。。

不过,这题是在线的,点分治似乎并不是那么好实现,所以我们想树链剖分。
这题其实就是树链剖分的模板题。

先将树求一遍dfs序,问题转化为区间求和、最大值问题。

考虑线段树。

dfs序、预处理等完善后,就可以进行树链剖分。
过程其实类似于倍增求LCA的过程

对于一段区间求和,将区间对应到树上,每次找两个点所对应的重链的顶端,不断对顶端到那个点的区间求和,一直往上递归即可,知道两点处于同一条链中,这时就可以求两点的区间。
用线段树很好实现。

Code
 #include<bits/stdc++.h>
 using namespace std;
 #define int long long
 const int N=60000;
 int n,m;
 int cnt=0,len=0;
struct node{
    int y,Next;
}e[N*2];
struct Node{
    int l,r,maxx,s;
};
int a[N*2];
int linkk[N*2];
int dfn[N*2];
int top[N*2];
int fa[N*2];
int pos[N*2];
int hv[N*2];
int sz[N*2];
int d[N*2];

void insert(int x,int y){
    e[++len]=(node){y,linkk[x]};
    linkk[x] = len;
} 

void dfshv(int x,int faa,int dd){
    d[x] = dd;
    sz[x] = 1;
    fa[x] = faa;
	for (int i=linkk[x];i;i=e[i].Next){
	    int y = e[i].y;
	    if (y == faa) continue;
	    dfshv(y,x,dd+1);
	    sz[x] += sz[y];
	    if (sz[hv[x]] < sz[y]) hv[x] = y;
	}
}


void dfs(int x,int rt){
	top[x] = rt;
	dfn[x] = ++cnt;
	pos[cnt] = x;
	if (hv[x] > 0) dfs(hv[x],rt);
	for (int i=linkk[x];i;i=e[i].Next){
	    int y=e[i].y;
	    if (y == fa[x] || y == hv[x]) continue;
	    dfs(y,y);
	}
}

struct Tr{
	Node tr[N*2+100];
	void build(int p,int l,int r){
	    tr[p].l=l,tr[p].r=r;
	    tr[p].s=0,tr[p].maxx=-10000000;
	    if (l == r) {tr[p].s = tr[p].maxx = a[pos[l]];return;}
	    int mid=l+r>>1;
	    build(p<<1,l,mid),build(p<<1|1,mid+1,r);
	    tr[p].maxx = max(tr[p<<1].maxx , tr[p<<1|1].maxx);
	    tr[p].s = tr[p<<1].s + tr[p<<1|1].s;
	}
	void change(int p,int x,int v){
	    int L=tr[p].l,R=tr[p].r;
	    if (L == R && L == x) {tr[p].maxx = tr[p].s = v;return;}
	    int mid = (L+R)>>1;
	    if (x<=mid) change(p<<1,x,v);
	    else change(p<<1|1,x,v);
	    tr[p].maxx = max(tr[p<<1].maxx , tr[p<<1|1].maxx);
	    tr[p].s = tr[p<<1].s + tr[p<<1|1].s;
	}
	int askm(int p,int l,int r){
	    int L=tr[p].l,R=tr[p].r;
	    if (L>r || R<l) return -10000000;
	    if (l<=L && R<=r) return tr[p].maxx;
	    int mid=L+R>>1;
	    int Maxx=-10000000;
	    if (l<=mid) Maxx=max(Maxx,askm(p<<1,l,r));
	    if (mid < r) Maxx=max(Maxx,askm(p<<1|1,l,r));
	    return Maxx;
	}
	int asks(int p,int l,int r){
	    int L=tr[p].l,R=tr[p].r;
	    if (L>r || R<l) return 0;
	    if (l<=L && R<=r) return tr[p].s;
	    int mid=L+R>>1;
	    int ss=0;
	    if (l<=mid) ss+=asks(p<<1,l,r);
	    if (mid < r) ss+=asks(p<<1|1,l,r);
	    return ss;
	}
};

Tr tr;

int qmax(int u,int v){
    int ans=-100000000;
    for (;top[u] != top[v];u=fa[top[u]]){
	    if (d[top[u]] < d[top[v]]) swap(u,v);
	    ans = max(ans , tr.askm(1,dfn[top[u]],dfn[u]));
	}
	if (d[u] > d[v]) swap(u,v);
	ans = max(ans,tr.askm(1,dfn[u],dfn[v]));
	return ans;
}

int qsum(int u,int v){
    int ans=0;
    for (;top[u] != top[v] ; u=fa[top[u]]){
	    if (d[top[u]] < d[top[v]]) swap(u,v);
	    ans += tr.asks(1,dfn[top[u]],dfn[u]);
	}
	if (d[u] > d[v]) swap(u,v);
	ans += tr.asks(1,dfn[u],dfn[v]);
	return ans;
}

signed main(){
    scanf("%lld",&n);
    for (int i=1,x,y;i<n;i++)
      scanf("%lld %lld",&x,&y),insert(x,y),insert(y,x);
    for (int i=1;i<=n;i++) scanf("%lld",&a[i]);
    dfshv(1,0,1);
    dfs(1,1);
    tr.build(1,1,n);
    scanf("%lld",&m);
    for (int i=1;i<=m;i++){
	    char c[10];
	    int x,y;
	    scanf("%s %lld %lld",c,&x,&y);
	    if (c[1] == 'H') tr.change(1,dfn[x],y);
	    if (c[1] == 'M') printf("%lld\n",qmax(x,y));
	    if (c[1] == 'S') printf("%lld\n",qsum(x,y));
	}
	return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值