第一次写树链剖分的题目,下面说下我对树链剖分的理解
以spoj为例,题意是给你一棵树,有两种操作,一种是修改某条边的权值,一种是询问节点a到节点b之间的路径上所有边的最大路径。
所谓树链剖分,就是把树上所有的路径剖分成一条条重链,每个节点都只属于一条重链,然后给每一条重链上的节点在线段树上重新编号(一条重链上的节点在线段树上一定是连续的),这样就把在树上的操作转换成了在线段树上操作,可以在O(log(n))内完成某些操作,比如求最值,求和等。
剖分之后,需要预处理很多信息,如,每个节点depth[i]节点的深度,father[i]节点的父节点,每一个节点的size[i](表示以这个节点为根节点的子树的节点的数量),son[i]节点的重儿子(节点所有儿子中size最大的那个,若size相同,则任取一个),top[i]表示节点所在重链的头节点,id[i]为节点在线段树的编号等信息
处理以上所有的信息只需要两个dfs就可以
第一个dfs, dfs1处理father, size, depth, son
代码如下:
void dfs1(int u, int f, int d)//计算depth, size, son, father
{
depth[u] = d;
Size[u] = 1;
father[u] = f;
int Max = 0;
for(int i = 0; i < V[u].size(); i++)
{
int v = V[u][i];
if(v != f)
{
dfs1(v, u, d + 1);
Size[u] += Size[v];
if(Size[v] > Max)
{
Max = Size[v];
son[u] = v;
}
}
}
}
第二个dfs,dfs2处理top[i], id[i],构造出重链,代码如下:
void dfs2(int u, int tp)
{
top[u] = tp;
id[u] = ++cnt;
if(son[u] != -1)
{
dfs2(son[u], tp);
}
for(int i = 0; i < V[u].size(); i++)
{
int v = V[u][i];
if(v != father[u] && v != son[u])
{
dfs2(v, v);
}
}
}
预处理出上面所有信息之后就很简单了,假如要询问a , b之间权值最大的边,那么每次我们求出top[a], top[b],假如depth[top[a]] > depth[top[b]],则在线段树询问top[a]到a 所对应的区间,然后另a = father[top[a]],重复上面过程,直到两点在同一条重链上,然后做最后一次询问就行(注意:只有两个点在同一条重链上(编号在线段树上连续,我们才可以直接询问,否则需要转换,方法如上)
spoj完整ac代码如下:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 10005;
int N;//节点个数
struct edge{
int u, v;
int w;
}Edge[maxn];//储存边
vector<int> V[maxn];//储存图
int top[maxn];//该节点所在重链的第一个头节点
int depth[maxn];//节点的深度
int Size[maxn];//以该节点为根节点的子树所有节点的个
int father[maxn];//节点的父节点
int id[maxn];//节点的父边
int Rank[maxn];//给每条边重新在线段树上编号
int son[maxn];//节点的重儿子
int cnt;
void init()
{
for(int i = 0; i <= N; i++)
{
V[i].clear();
}
memset(son, -1, sizeof(son));
cnt = 0;
}
void dfs1(int u, int f, int d)//计算depth, size, son, father
{
depth[u] = d;
Size[u] = 1;
father[u] = f;
int Max = 0;
for(int i = 0; i < V[u].size(); i++)
{
int v = V[u][i];
if(v != f)
{
dfs1(v, u, d + 1);
Size[u] += Size[v];
if(Size[v] > Max)
{
Max = Size[v];
son[u] = v;
}
}
}
}
void dfs2(int u, int tp)
{
top[u] = tp;
id[u] = ++cnt;
if(son[u] != -1)
{
dfs2(son[u], tp);
}
for(int i = 0; i < V[u].size(); i++)
{
int v = V[u][i];
if(v != father[u] && v != son[u])
{
dfs2(v, v);
}
}
}
struct node{
int l, r;
int Max;
}Node[4 * maxn];
void pushUp(int i)
{
int lson = i<<1;
int rson = lson + 1;
Node[i].Max = max(Node[lson].Max, Node[rson].Max);
}
void build(int i, int l, int r)
{
Node[i].l = l;
Node[i].r = r;
Node[i].Max = 0;
if(l == r)
{
Node[i].Max = Rank[l];
return;
}
int f = i;
i <<= 1;
int mid = (l + r)>>1;
build(i, l, mid);
build(i|1, mid + 1, r);
pushUp(f);
}
void update(int i, int loc, int value)
{
if(Node[i].l == Node[i].r )
{
Node[i].Max = value;
return;
}
int f = i;
i <<= 1;
if(loc <= Node[i].r) update(i, loc, value);
else update(i|1, loc, value);
pushUp(f);
}
int query(int i, int l, int r)
{
if(Node[i].l == l && Node[i].r == r) return Node[i].Max;
i <<= 1;
if(r <= Node[i].r) return query(i, l, r);
else if(l >= Node[i|1].l) return query(i|1, l, r);
else return max(query(i, l, Node[i].r), query(i|1, Node[i|1].l, r));
}
int solve(int u, int v)
{
int f1 = top[u];
int f2 = top[v];
int ans = 0;
while(f1 != f2)
{
if(depth[f1] < depth[f2])
{
ans = max(query(1, id[f2], id[v]),ans);
v = father[f2];
}
else
{
ans = max(query(1, id[f1], id[u]), ans);
u = father[f1];
}
f1 = top[u];
f2 = top[v];
}
if(u == v) return ans;
if(depth[u] < depth[v]) swap(u, v);
ans = max(query(1, id[son[v]], id[u]), ans);
return ans;
}
int main()
{
int Case;
scanf("%d", &Case);
while(Case--)
{
scanf("%d", &N);
init();
for(int i = 1; i < N; i++)
{
scanf("%d%d%d", &Edge[i].u, &Edge[i].v, &Edge[i].w);
V[Edge[i].u].push_back(Edge[i].v);
V[Edge[i].v].push_back(Edge[i].u);
}
dfs1(1, 1, 0);
dfs2(1, 1);
for(int i = 1; i < N; i++)//求解每个点的父边
{
int u = Edge[i].u;
int v = Edge[i].v;
if(depth[u] < depth[v])
{
swap(u, v);
swap(Edge[i].u, Edge[i].v);
}
Rank[id[u]] = Edge[i].w;
}
Rank[1] = 0;
char op[10];
build(1, 1, cnt);
while(scanf("%s", op) == 1)
{
int u, v;
if(op[0] == 'D') break;
else if(op[0] == 'Q')
{
scanf("%d%d", &u, &v);
printf("%d\n", solve(u, v));
}
else if(op[0] == 'C')
{
scanf("%d%d", &u, &v);
int res = id[Edge[u].u];
update(1, res, v);
}
}
}
return 0;
}