最近实验室的小伙伴搞数据结构搞得很厉害,网选的时候又出现了几个与树链剖分有关的题目,最近有时间,就也自己学习了一下。其实树链剖分并不是什么新的算法,只是dfs+线段树或者树状数组。只是当中的技巧性比较强,看懂之后不禁感慨是谁想出的这么高大上的算法。树链剖分,看名字就知道是需要对一棵树进行剖分。具体的剖分规则:先将边按照子树节点的总数量的多少分为重边和轻边(其中:子树节点最多的边为重边,其余的为轻边);然后按照所有相连的重边会组成若干条链。对这些链用树状数组或者线段树的处理。用来完成对树的快速的更新和查询工作。具体实现需要用到六个重要的数组。分别记录各个节点的父节点,节点深度,以该节点为根的子树的节点的个数,节点所在重链的顶点,以及各条边所对应的编号,编号要保证同一个重链上的边的序号要连续。便于用线段树或树状数组来维护。然后就将查询和更新交给线段树啦。
spoj375树链剖分,点更新和区间查询,用线段树实现;(点查询和区间更新)
#include<cstdio>
#include<climits>
#include<cstring>
#include<algorithm>
#include<vector>
#define MAX 100100
#define INF 0
using namespace std;
struct Edge{
int from, to, dist;
};
vector<int>G[MAX];
vector<Edge>edges;
int fa[MAX],son[MAX],dep[MAX],siz[MAX],top[MAX],w[MAX];
void init(){
edges.clear();
for (int i = 0; i<MAX; i++) G[i].clear();
memset(fa,0,sizeof(fa));
memset(siz,0,sizeof(siz));
memset(dep,0,sizeof(dep));
}
void AddEdge(int from, int to, int dist){
edges.push_back((Edge){from, to, dist});
edges.push_back((Edge){to,from,dist});
int k = edges.size();
G[from].push_back(k-2);
G[to].push_back(k-1);
}
void dfs1(int u){ //dfs1求出fa,dep,siz,son;
siz[u] = 1;
son[u] = 0;
for (int i = 0; i<G[u].size(); i++){
Edge& e = edges[G[u][i]];
if (e.to != fa[u]){
fa[e.to] = u;
dep[e.to] = dep[u] + 1;
dfs1(e.to);
if (siz[son[u]] < siz[e.to]) son[u] = e.to;
siz[u] += siz[e.to];
}
}
}
int cnt;
void dfs2(int u, int pt){ //dfs2求出top,w;
top[u] = pt;
w[u] = ++cnt;
if (son[u] != 0 ) dfs2(son[u],pt); //主链优先搜,保证同一条连上的边的编号连续
for (int i = 0; i<G[u].size(); i++){
Edge& e = edges[G[u][i]];
if (e.to != fa[u] && e.to != son[u]){
dfs2(e.to,e.to);
}
}
}
struct Node{
int root,L,R;
int maxv;
}a[MAX];
void build(int root, int l, int r){
if (l > r) return;
a[root].L = l;
a[root].R = r;
a[root].maxv = -INF;
if (l == r) return;
int mid = a[root].L + (a[root].R-a[root].L) / 2;
build(root*2+1, l, mid);
build(root*2+2, mid+1,r);
}
void update(int root, int u, int x){
if (u < a[root].L || u > a[root].R) return;
if (a[root].L == a[root].R){
a[root].maxv = x;
return;
}
int mid = a[root].L + (a[root].R-a[root].L) / 2;
if (u <= mid){
update(root*2+1, u, x);
}
else{
update(root*2+2, u, x);
}
a[root].maxv = max(a[root*2+1].maxv, a[root*2+2].maxv);
return;
}
int query(int root, int l, int r){
if (l > a[root].R || r < a[root].L) return 0;
if (a[root].L >= l && a[root].R <= r){
return a[root].maxv;
}
int mid = a[root].L + (a[root].R - a[root].L) / 2;
if (l > mid){
return query(root*2+2,l,r);
}
else if (r <= mid){
return query(root*2+1,l,r);
}
else{
return max(query(root*2+1, l, mid), query(root*2+2,mid+1,r));
}
}
int find(int u, int v){ //注意查询的技巧
int f1 = top[u], f2 = top[v], ans = 0;
while (f1 != f2){ //当区间不在同一条连上时,需要一步一步来往树根方向查,并记录。直到区间在同一条重链上
if (dep[f1] < dep[f2]){
swap(f1,f2);
swap(u,v);
}
ans = max(ans,query(0,w[f1],w[u]));
u = fa[f1];
f1 = top[u];
}
if (u != v){ //当区间在同一条重链上时,直接用线段树或树状数组维护
if (dep[u] > dep[v]){
swap(u,v);
}
ans = max(ans,query(0,w[son[u]],w[v]));
}
return ans;
}
int d[MAX][3];
int main(){
int n,T;
scanf("%d",&T);
while(T--){
scanf("%d",&n);
init();
int x,y,z;
for (int i = 1; i<n; i++){
scanf("%d%d%d",&x,&y,&z);
d[i][0] = x;
d[i][1] = y;
d[i][2] = z;
AddEdge(x,y,z);
}
int r = (n+1) / 2;
dfs1(r);
cnt = 0;
dfs2(r, r);
build(0,1,cnt);
for (int i = 1; i<n; i++){
if (dep[d[i][0]] > dep[d[i][1]]) swap(d[i][0], d[i][1]);
update(0,w[d[i][1]],d[i][2]);
}
char s[20];
int u,v;
while (scanf("%s",s) && strcmp(s,"DONE") != 0){
scanf("%d%d",&u,&v);
if (strcmp(s,"QUERY") == 0) printf("%d\n",find(u,v));
else {
update(0,w[d[u][1]],v);
}
}
}
return 0;
}
hdu3966区间更新和点查询,用树状数组维护;
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<vector>
#define MAX 50010
#define ll __int64
using namespace std;
vector<int>G[MAX];
int d[MAX];
int fa[MAX],dep[MAX],siz[MAX],son[MAX],cnt;
int w[MAX],top[MAX];
ll c[MAX];
void init(int n){
for (int i = 0; i<=n; i++) G[i].clear();
memset(dep,0,sizeof(dep));
memset(c,0,sizeof(c));
memset(siz,0,sizeof(siz));
memset(son,0,sizeof(son));
cnt = 0;
}
void dfs1(int u){
siz[u] = 1;
son[u] = 0;
for (int i = 0; i<G[u].size(); i++){
int e = G[u][i];
if (e != fa[u]){
fa[e] = u;
dep[e] = dep[u] + 1;
dfs1(e);
siz[u] += siz[e];
if (siz[e] > siz[son[u]]){
son[u] = e;
}
}
}
}
void dfs2(int u, int pt){
w[u] = ++cnt;
top[u] = pt;
if (son[u] != 0) dfs2(son[u], pt);
for (int i = 0; i<G[u].size(); i++){
int e = G[u][i];
if (e != fa[u] && e != son[u]){
dfs2(e, e);
}
}
}
int lowbit(int x){
return x&(-x);
}
void add(int u, int x){
for (int i = u; i<=cnt; i += lowbit(i))
c[i] += x;
}
ll sum(int u){
ll s = 0;
for (int i = u; i>0; i-= lowbit(i)){
s += c[i];
}
return s;
}
void update(int x, int y, int z){
int f1 = top[x], f2 = top[y];
while (f1 != f2){
if (dep[f1] < dep[f2]){
swap(f1,f2);
swap(x,y);
}
add(w[f1], z);
add(w[x]+1, -z);
x = fa[f1];
f1 = top[x];
}
if (dep[x] > dep[y]){
swap(x,y);
}
add(w[x], z);
add(w[y]+1, -z);
}
int main(){
int n,m,p;
while (scanf("%d%d%d",&n,&m,&p) != EOF){
init(n);
for (int i = 1; i<=n; i++) scanf("%d",&d[i]);
int x,y;
for (int i = 0; i<m; i++){
scanf("%d%d",&x,&y);
G[x].push_back(y);
G[y].push_back(x);
}
dfs1(1);
dfs2(1,1);
/* for (int i = 1; i<=n; i++) printf("%d ",fa[i]); printf("\n"); //测试dfs1,dfs2
for (int i = 1; i<=n; i++) printf("%d ",dep[i]); printf("\n");
for (int i = 1; i<=n; i++) printf("%d ",siz[i]); printf("\n");
for (int i = 1; i<=n; i++) printf("%d ",son[i]); printf("\n");
for (int i = 1; i<=n; i++) printf("%d ",top[i]); printf("\n");
for (int i = 1; i<=n; i++) printf("%d ",w[i]); printf("\n");*/
for (int i = 1; i<=n; i++){
add(w[i],d[i]);
add(w[i]+1,-d[i]);
}
char s[10];
int z;
for (int i = 0; i<p; i++){
scanf("%s",s);
if (s[0] == 'I'){
scanf("%d%d%d",&x,&y,&z);
update(x,y,z);
}
else if (s[0] == 'D'){
scanf("%d%d%d",&x,&y,&z);
update(x,y,-z);
}
else{
scanf("%d",&x);
printf("%I64d\n",sum(w[x]));
}
}
}
return 0;
}
hdu3966 用线段树维护;
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<vector>
#define MAX 50010
#define ll __int64
using namespace std;
vector<int>G[MAX];
int d[MAX];
int fa[MAX],dep[MAX],siz[MAX],son[MAX],cnt;
int w[MAX],top[MAX];
void init(int n){
for (int i = 0; i<=n; i++) G[i].clear();
memset(dep,0,sizeof(dep));
memset(siz,0,sizeof(siz));
memset(son,0,sizeof(son));
cnt = 0;
}
void dfs1(int u){
siz[u] = 1;
son[u] = 0;
for (int i = 0; i<G[u].size(); i++){
int e = G[u][i];
if (e != fa[u]){
fa[e] = u;
dep[e] = dep[u] + 1;
dfs1(e);
siz[u] += siz[e];
if (siz[e] > siz[son[u]]){
son[u] = e;
}
}
}
}
void dfs2(int u, int pt){
w[u] = ++cnt;
top[u] = pt;
if (son[u] != 0) dfs2(son[u], pt);
for (int i = 0; i<G[u].size(); i++){
int e = G[u][i];
if (e != fa[u] && e != son[u]){
dfs2(e, e);
}
}
}
struct cNode{
int root,L,R;
ll sum,inc;
}a[4*MAX];
void build(int root, int L, int R){
a[root].L = L;
a[root].R = R;
a[root].inc = 0;
a[root].sum = 0;
if (L != R){
int M = L + (R - L) / 2;
build(root*2 + 1, L, M);
build(root*2 + 2, M + 1, R);
a[root].sum = a[root*2+1].sum + a[root*2+2].sum;
}
}
void update(int root, int s,int e, int v){
if (a[root].L == s && a[root].R == e){
a[root].inc = a[root].inc + v;
return;
}
a[root].sum += v * (e-s+1);
int M = a[root].L + (a[root].R - a[root].L) / 2;
if (e <= M){
update(root*2 + 1, s, e, v);
}
else if (s > M){
update(root*2 + 2, s, e, v);
}
else{
update(root*2 + 1 , s, M, v);
update(root*2 + 2, M + 1, e, v);
}
}
void renew(int x, int y, int z){
int f1 = top[x],f2 = top[y];
while (f1 != f2){
if (dep[f1] < dep[f2]){
swap(f1,f2);
swap(x,y);
}
update(0,w[f1],w[x],z);
x = fa[f1];
f1 = top[x];
}
if (dep[x] < dep[y]){
swap(x,y);
}
update(0,w[y],w[x],z);
}
ll sumv;
void query(int root, int s, int e){
if (a[root].L == s && a[root].R == e){
sumv += a[root].sum + a[root].inc * (e - s + 1);
return;
}
if (a[root].L != a[root].R){
a[root*2+1].inc += a[root].inc;
a[root*2+2].inc += a[root].inc;
}
a[root].sum += a[root].inc * (a[root].R - a[root].L + 1);
a[root].inc = 0;
int M = a[root].L + (a[root].R - a[root].L) / 2;
if (e <= M){
query(root*2 + 1 , s, e);
}
else if ( s > M){
query(root*2 + 2, s, e);
}
else{
query(root*2 + 1, s , M);
query(root*2 + 2, M + 1, e);
}
}
int main(){
int n,m,q;
while (scanf("%d%d%d",&n,&m,&q) != EOF){
init(n);
for (int i = 1; i<=n; i++) scanf("%d",&d[i]);
int x,y,z;
for (int i = 0; i<m; i++){
scanf("%d%d",&x,&y);
G[x].push_back(y);
G[y].push_back(x);
}
dfs1(1);
dfs2(1,1);
build(0,1,cnt);
for (int i = 1; i<=n; i++){
update(0,w[i],w[i],d[i]);
}
char s[5];
for (int i = 0; i<q; i++){
scanf("%s",s);
if (s[0] == 'I'){
scanf("%d%d%d",&x,&y,&z);
renew(x,y,z);
}
else if (s[0] == 'D'){
scanf("%d%d%d",&x,&y,&z);
renew(x,y,-z);
}
else{
scanf("%d",&x);
sumv = 0;
query(0,w[x],w[x]);
printf("%I64d\n",sumv);
}
}
}
return 0;
}