树链剖分

最近实验室的小伙伴搞数据结构搞得很厉害,网选的时候又出现了几个与树链剖分有关的题目,最近有时间,就也自己学习了一下。其实树链剖分并不是什么新的算法,只是dfs+线段树或者树状数组。只是当中的技巧性比较强,看懂之后不禁感慨是谁想出的这么高大上的算法。树链剖分,看名字就知道是需要对一棵树进行剖分。具体的剖分规则:先将边按照子树节点的总数量的多少分为重边和轻边(其中:子树节点最多的边为重边,其余的为轻边);然后按照所有相连的重边会组成若干条链。对这些链用树状数组或者线段树的处理。用来完成对树的快速的更新和查询工作。具体实现需要用到六个重要的数组。分别记录各个节点的父节点,节点深度,以该节点为根的子树的节点的个数,节点所在重链的顶点,以及各条边所对应的编号,编号要保证同一个重链上的边的序号要连续。便于用线段树或树状数组来维护。然后就将查询和更新交给线段树啦。

spoj375树链剖分,点更新和区间查询,用线段树实现;(点查询和区间更新)

#include<cstdio>
#include<climits>
#include<cstring>
#include<algorithm>
#include<vector>
#define MAX 100100
#define INF 0
 
using namespace std;
 
struct Edge{
	int from, to, dist;
};
 
vector<int>G[MAX];
vector<Edge>edges;
 
int fa[MAX],son[MAX],dep[MAX],siz[MAX],top[MAX],w[MAX];
 
void init(){
	edges.clear();
	for (int i = 0; i<MAX; i++) G[i].clear();
	memset(fa,0,sizeof(fa));
	memset(siz,0,sizeof(siz));
	memset(dep,0,sizeof(dep));
} 
 
void AddEdge(int from, int to, int dist){      
	edges.push_back((Edge){from, to, dist});
	edges.push_back((Edge){to,from,dist});
	int k = edges.size();
	G[from].push_back(k-2);
	G[to].push_back(k-1);
}
 
void dfs1(int u){          //dfs1求出fa,dep,siz,son;
	siz[u] = 1;
	son[u] = 0;
	for (int i = 0; i<G[u].size(); i++){
		Edge& e = edges[G[u][i]];
		if (e.to != fa[u]){
			fa[e.to] = u;
			dep[e.to] = dep[u] + 1;
			dfs1(e.to);
			if (siz[son[u]] < siz[e.to]) son[u] = e.to;
			siz[u] += siz[e.to];
		}
	}
}
 
int cnt;
 
void dfs2(int u, int pt){  //dfs2求出top,w;
	top[u] = pt;
	w[u] = ++cnt;
	if (son[u] != 0 ) dfs2(son[u],pt);     //主链优先搜,保证同一条连上的边的编号连续
	for (int i = 0; i<G[u].size(); i++){
		Edge& e = edges[G[u][i]];
		if (e.to != fa[u] && e.to != son[u]){
			dfs2(e.to,e.to);
		}
	}
}
 
struct Node{    
	int root,L,R;
	int maxv;
}a[MAX];
 
void build(int root, int l, int r){    
	if (l > r) return;
	a[root].L = l;
	a[root].R = r;
	a[root].maxv = -INF;
	
	if (l == r) return;
	
	int mid = a[root].L + (a[root].R-a[root].L) / 2;
	build(root*2+1, l, mid);
	build(root*2+2, mid+1,r);
}
 
void update(int root, int u, int x){       
	if (u < a[root].L || u > a[root].R) return;
	if (a[root].L == a[root].R){
		a[root].maxv = x;
		return;
	}
	 
    int mid = a[root].L + (a[root].R-a[root].L) / 2;
 	
 	if (u <= mid){
    	update(root*2+1, u, x);
    }
    else{
    	update(root*2+2, u, x);
    }
    
    a[root].maxv = max(a[root*2+1].maxv, a[root*2+2].maxv);
	return;
}
 
int query(int root, int l, int r){     
	
	if (l > a[root].R || r < a[root].L) return 0;
 
	if (a[root].L >= l && a[root].R <= r){
		return a[root].maxv;
	}
	
	int mid = a[root].L + (a[root].R - a[root].L) / 2;
	
	if (l > mid){
		return query(root*2+2,l,r);
	}
	else if (r <= mid){
		return query(root*2+1,l,r);
	}
	else{
		return max(query(root*2+1, l, mid), query(root*2+2,mid+1,r));
	}
}
 
int find(int u, int v){        //注意查询的技巧
	int f1 = top[u], f2 = top[v], ans = 0;   
	while (f1 != f2){             //当区间不在同一条连上时,需要一步一步来往树根方向查,并记录。直到区间在同一条重链上
		if (dep[f1] < dep[f2]){
			swap(f1,f2);
			swap(u,v);
		}
		ans = max(ans,query(0,w[f1],w[u]));
		u = fa[f1];
		f1 = top[u];
	}
	
	if (u != v){       //当区间在同一条重链上时,直接用线段树或树状数组维护
		if (dep[u] > dep[v]){
			swap(u,v);
		}
		ans = max(ans,query(0,w[son[u]],w[v]));
	}
	
	return ans;
	
}
 
int d[MAX][3];
 
int main(){
	int n,T;
	scanf("%d",&T);
	while(T--){
		scanf("%d",&n);
		init();
		int x,y,z;
		for (int i = 1; i<n; i++){
			scanf("%d%d%d",&x,&y,&z);
			d[i][0] = x;
			d[i][1] = y;
			d[i][2] = z;
			AddEdge(x,y,z);
		}
		
		int r = (n+1) / 2;
		dfs1(r);
		
		cnt = 0;
		dfs2(r, r);
		 
		build(0,1,cnt);
		
		for (int i = 1; i<n; i++){
			if (dep[d[i][0]] > dep[d[i][1]]) swap(d[i][0], d[i][1]);
			update(0,w[d[i][1]],d[i][2]);
		}
		
		char s[20];
        int u,v;
		
        while (scanf("%s",s) && strcmp(s,"DONE") != 0){
            scanf("%d%d",&u,&v);
            if (strcmp(s,"QUERY") == 0) printf("%d\n",find(u,v));
            else {
                update(0,w[d[u][1]],v);
            }
        }				
	} 
	
	return 0;	
}
 
 

hdu3966区间更新和点查询,用树状数组维护;

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<vector>
#define MAX 50010
#define ll __int64

using namespace std;

vector<int>G[MAX];

int d[MAX];
int fa[MAX],dep[MAX],siz[MAX],son[MAX],cnt;
int w[MAX],top[MAX];
ll c[MAX];

void init(int n){
	for (int i = 0; i<=n; i++) G[i].clear();
	memset(dep,0,sizeof(dep));
	memset(c,0,sizeof(c));
	memset(siz,0,sizeof(siz));
	memset(son,0,sizeof(son));
	cnt = 0;
}

void dfs1(int u){
	siz[u] = 1;
	son[u] = 0;
	for (int i = 0; i<G[u].size(); i++){
		int e = G[u][i];
		if (e != fa[u]){
			fa[e] = u;
			dep[e] = dep[u] + 1;
			dfs1(e);
			siz[u] += siz[e];
			if (siz[e] > siz[son[u]]){
				son[u] = e;
			}
		}
	}	
}

void dfs2(int u, int pt){
	w[u] = ++cnt;
	top[u] = pt;
	if (son[u] != 0) dfs2(son[u], pt);
	for (int i = 0; i<G[u].size(); i++){
		int e = G[u][i];
		if (e != fa[u] && e != son[u]){
			dfs2(e, e);
		}
	}
}

int lowbit(int x){
	return x&(-x);
}

void add(int u, int x){
	for (int i = u; i<=cnt; i += lowbit(i))
		c[i] += x;
}

ll sum(int u){
	ll s = 0;
	for (int i = u; i>0; i-= lowbit(i)){
		s += c[i];
	}
	return s;
}

void update(int x, int y, int z){
	int f1 = top[x], f2 = top[y];
	while (f1 != f2){
		if (dep[f1] < dep[f2]){
			swap(f1,f2);
			swap(x,y);
		}
	 
		add(w[f1], z);
		add(w[x]+1, -z);
		x = fa[f1];
		f1 = top[x];
	}
	
 
	if (dep[x] > dep[y]){
		swap(x,y);
	}
	 
	add(w[x], z);
	add(w[y]+1, -z);
}

int main(){
	int n,m,p;
	while (scanf("%d%d%d",&n,&m,&p) != EOF){
		init(n);
		for (int i = 1; i<=n; i++) scanf("%d",&d[i]);
		int x,y;
		for (int i = 0; i<m; i++){
			scanf("%d%d",&x,&y);
			G[x].push_back(y);
			G[y].push_back(x);
		}
		
		dfs1(1);
		dfs2(1,1);
		
	/*	for (int i = 1; i<=n; i++) printf("%d ",fa[i]); printf("\n");   //测试dfs1,dfs2
		for (int i = 1; i<=n; i++) printf("%d ",dep[i]); printf("\n");
		for (int i = 1; i<=n; i++) printf("%d ",siz[i]); printf("\n");
		for (int i = 1; i<=n; i++) printf("%d ",son[i]); printf("\n");
		for (int i = 1; i<=n; i++) printf("%d ",top[i]); printf("\n");
		for (int i = 1; i<=n; i++) printf("%d ",w[i]); printf("\n");*/
		
		for (int i = 1; i<=n; i++){
			add(w[i],d[i]);
			add(w[i]+1,-d[i]);
		}
				
		char s[10];
		int z;
		for (int i = 0; i<p; i++){
			scanf("%s",s);
			if (s[0] == 'I'){
				scanf("%d%d%d",&x,&y,&z);
				update(x,y,z);
			}
			else if (s[0] == 'D'){
				scanf("%d%d%d",&x,&y,&z);
				update(x,y,-z);
			}
			else{
				scanf("%d",&x);
				printf("%I64d\n",sum(w[x]));
			}
		}	
	}
	return 0;
}

hdu3966  用线段树维护;

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<vector>
#define MAX 50010
#define ll __int64

using namespace std;

vector<int>G[MAX];

int d[MAX];
int fa[MAX],dep[MAX],siz[MAX],son[MAX],cnt;
int w[MAX],top[MAX];

void init(int n){
	for (int i = 0; i<=n; i++) G[i].clear();
	memset(dep,0,sizeof(dep));
	memset(siz,0,sizeof(siz));
	memset(son,0,sizeof(son));
	cnt = 0;
}

void dfs1(int u){
	siz[u] = 1;
	son[u] = 0;
	for (int i = 0; i<G[u].size(); i++){
		int e = G[u][i];
		if (e != fa[u]){
			fa[e] = u;
			dep[e] = dep[u] + 1;
			dfs1(e);
			siz[u] += siz[e];
			if (siz[e] > siz[son[u]]){
				son[u] = e;
			}
		}
	}	
}

void dfs2(int u, int pt){
	w[u] = ++cnt;
	top[u] = pt;
	if (son[u] != 0) dfs2(son[u], pt);
	for (int i = 0; i<G[u].size(); i++){
		int e = G[u][i];
		if (e != fa[u] && e != son[u]){
			dfs2(e, e);
		}
	}
}

struct cNode{                                       
    int root,L,R;  
    ll sum,inc;                                     
}a[4*MAX];  

void build(int root, int L, int R){                
    a[root].L = L;  
    a[root].R = R;  
    a[root].inc = 0;
    a[root].sum = 0;
    if (L != R){  
        int M = L + (R - L) / 2;  
        build(root*2 + 1, L, M);  
        build(root*2 + 2, M + 1, R); 
        a[root].sum = a[root*2+1].sum + a[root*2+2].sum; 
    }  
}  
  
void update(int root, int s,int e, int v){   
          
    if (a[root].L == s && a[root].R == e){ 
        a[root].inc = a[root].inc + v;  
        return;  
    }  
    a[root].sum += v * (e-s+1);         
    int M = a[root].L + (a[root].R - a[root].L) / 2;  
    if (e <= M){  
        update(root*2 + 1, s, e, v);  
    }  
    else if (s > M){  
        update(root*2 + 2, s, e, v);  
    }  
    else{  
        update(root*2 + 1 , s, M, v);  
        update(root*2 + 2, M + 1, e, v);  
    }  
}  

void renew(int x, int y, int z){
	int f1 = top[x],f2 = top[y];
	while (f1 != f2){
		if (dep[f1] < dep[f2]){
			swap(f1,f2);
			swap(x,y);
		}
		update(0,w[f1],w[x],z);
		x = fa[f1];
		f1 = top[x];
	}	
	if (dep[x] < dep[y]){
		swap(x,y);
	}
	update(0,w[y],w[x],z);
}

ll sumv;  
  
void query(int root, int s, int e){  
    if (a[root].L == s && a[root].R == e){           
         sumv += a[root].sum + a[root].inc * (e - s + 1);     
         return;  
    }  
    if (a[root].L != a[root].R){   
        a[root*2+1].inc += a[root].inc;                 
        a[root*2+2].inc += a[root].inc;  
    }  
    a[root].sum += a[root].inc * (a[root].R - a[root].L + 1);  
    a[root].inc = 0;  
    int M = a[root].L + (a[root].R - a[root].L) / 2;  
    if (e <= M){  
        query(root*2 + 1 , s, e);  
    }  
    else if ( s > M){  
        query(root*2 + 2, s, e);  
    }  
    else{  
        query(root*2 + 1, s , M);  
        query(root*2 + 2, M + 1, e);  
    }  
} 

int main(){  
     int n,m,q;
     while (scanf("%d%d%d",&n,&m,&q) != EOF){
         init(n);
         for (int i = 1; i<=n; i++) scanf("%d",&d[i]);
         int x,y,z;
         for (int i = 0; i<m; i++){
             scanf("%d%d",&x,&y);
             G[x].push_back(y);
             G[y].push_back(x);
         }
         
         dfs1(1);
         dfs2(1,1);
         build(0,1,cnt);
         
         for (int i = 1; i<=n; i++){
         	update(0,w[i],w[i],d[i]);
         }
                  
         char s[5];
         for (int i = 0; i<q; i++){
             
             scanf("%s",s);
             if (s[0] == 'I'){
                 scanf("%d%d%d",&x,&y,&z);
                 renew(x,y,z);
             }    
             else if (s[0] == 'D'){
                 scanf("%d%d%d",&x,&y,&z);
                 renew(x,y,-z);
             }
             else{
                 scanf("%d",&x);
                 sumv = 0;
                 query(0,w[x],w[x]);
                 printf("%I64d\n",sumv);
             }
         }
         
     }
     return 0;
     
}  






  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值