(同步个人博客http://sxysxy.org/blogs/3 到csdn)
题目大意:
有一棵有N个节点,N-1条边的树,每个节点都有一个权值w。现在要求支持以下操作: CHANGE u t 把节点u的权值修改为t,QMAX u v 询问从节点u到v路径上权值最大的节点,QSUM询问从节点u到节点v上所有节点的权值之和。(N <= 3W ,修改/询问次数 <= 20W)
LCT并不会写orz,于是树链剖分,这也是我第一道用树链剖分AC的题,纪念之。
(备注,本题可以在http://syzoj.com/problem/47 , http://cojs.tk/cogs/problem/problem.php?pid=1688, 以及 http://www.lydsy.com/JudgeOnline/problem.php?id=1036 上提交)
首先这个树链剖分的大体思路是这样:树上两个节点之间的路径是唯一的(不绕圈,每条边只经过一次),也就是说对于同一对点u,v他们之间路途覆盖的范围是不变的,现在要求维护这个范围上的信息,于是第一时间就想到线段树。
然而线段树维护的区间总是在某种意义上连续的。我们就需要给点(这道题中维护的是点,树链剖分也可以维护边的信息)进行编号。树链剖分中一个重要的步骤就是给节点(边)编号。剖分后的树链(顺便一说:树链即树上的路径)有着特殊的性质,能够使得我们能在较少时间内完成树上的查询维护操作。下边对此进行一些讲解。
废话完毕
我们定义 size[x]表示以x为根节点的子树的节点数,deep[x]表示x在的深度(这里设树根的深度为1),parent[x]为x的父节点,son[x]表示x节点的“重儿子“(重儿子指的是x的子节点中size最大的节点),top[x]为x所在的链的顶端的节点。 另外定义两个辅助的数组,intree[x]表示树中节点x在线段树(用来维护树上信息的线段树)中的节点号,outree是intree的反函数,outree[x]表示线段树中节点x对应要维护的树中的节点号。上面这些信息可以通过对树进行两次dfs得到。(也能bfs,过程见后面给出的代码)
剖分时要将树链剖分成轻链和重链,重链是由重边组成的链,节点x与son[x]的连边是一条重边,x与他的其余子节点的连边为轻边。轻链就是轻边组成的链。
下面是一个剖分的树的图例:
红色的边的重边,绿色的是轻边
剖分后的树具有以下性质:
- 1.对于树中两个节点u, v,若u是v的子节点(或者说deep[u] > deep[v]),且u, v之间的连边是轻边,那么必定有size[u]*2 < size[v]
- 2.从根节点到树中任意一点的路径上,存在不超过logN条轻链和不超过logN条重链。(N是节点总数)
- 3.树中重链上的节点(或者边,这要看具体问题)在线段树重的intree值是连续的。也就是说重链上的节点在线段树中的编号形成了一段连续的区间,我们是可以直接对这个区间进行查询的。
- 4.同一条链上的节点中,deep小的一定是deep大的的节点的父节点,这就意味着同一条链上deep小的节点比deep大的节点在线段树中的intree值小(这是废话,但是在查询是这一点很有用,线段树接受的查询的区间[l,r]满足l<=r,我们需要保证这一点,而保证这一点的方法就是通过deep值判断。。。)。
修改x节点的权值就直接修改线段树里面intree[x]的值,能学到树剖想必线段树这方面不需要过多赘述。
- 查询的时候,对于节点u,v,如果u,v在同一条链上(即top[u] == top[v]),直接向线段树查询就可以。
- 当top[u] != top[v]时,记t1 = top[u], t2 = top[v], 我们钦定deep[t1] >= deep[t2](若不满足则交换u,v,交换t1, t2满足),。此时查询t1到u(t1是u所在链的顶端,即deep[t1] <= deep[u],即线段树中intree[t1] <= intree[u],满足线段树查询的要求(见性质4),可以进行查询。)查询完毕及时更新答案,然后令u = parent[t1], t1 = top[u],t1与u之间形成一段新的链(即查询的时候一段一段地查),对这段链再查询。。。重复以上操作直到top[u] == top[v],这时候就可以直接向线段树查询了,这是最后一次更新答案,完成查询。
(PS:对于需要维护树上边的信息的树链剖分,可以实行”边化点”,让每条边两端较深的一点代替这条边。之后就和维护树上点的一样啦。
下面给出本题我的AC代码:
#include <cstdio>
#include <cstdlib>
#include <cstdarg>
#include <cstring>
#include <string>
#include <vector>
#include <queue>
#include <list>
#include <algorithm>
using namespace std;
#define MAXN (30010)
#define BETTER_CODE __attribute__((optimize("O3")))
vector<int> G[MAXN];
int top[MAXN], son[MAXN], parent[MAXN], value[MAXN], size[MAXN], deep[MAXN];
int intree[MAXN], outree[MAXN];
int num;
BETTER_CODE
void dfs1(int cur, int fa, int dep)
{
parent[cur] = fa;
deep[cur] = dep;
size[cur] = 1;
son[cur] = 0;
for(int i = 0; i < G[cur].size(); i++)
{
int nx = G[cur][i];
if(nx != fa)
{
dfs1(nx, cur, dep+1);
size[cur] += size[nx];
if(size[son[cur]] < size[nx])
son[cur] = nx;
}
}
}
BETTER_CODE
void dfs2(int cur, int tp)
{
top[cur] = tp;
intree[cur] = ++num;
outree[intree[cur]] = cur;
if(!son[cur])return;
dfs2(son[cur], tp);
for(int i = 0; i < G[cur].size(); i++)
{
int nx = G[cur][i];
if(nx == parent[cur] || nx == son[cur])continue;
dfs2(nx, nx);
}
}
class segtree
{
public:
struct node
{
int l, r;
int maxi;
int sum;
}ns[MAXN<<2];
#define mid(a,b) ((a+b)>>1)
#define ls(x) (x<<1)
#define rs(x) ((x<<1)|1)
BETTER_CODE
void build(int c, int l, int r)
{
ns[c].l = l;
ns[c].r = r;
if(l == r)
{
ns[c].maxi = value[outree[l]];
ns[c].sum = value[outree[l]];
return;
}
int m = mid(l, r);
build(ls(c), l, m);
build(rs(c), m+1, r);
ns[c].maxi = max(ns[ls(c)].maxi, ns[rs(c)].maxi);
ns[c].sum = ns[ls(c)].sum + ns[rs(c)].sum;
}
BETTER_CODE
void update(int c, int v)
{
if(ns[c].l == ns[c].r)
{
ns[c].maxi = value[outree[ns[c].l]];
ns[c].sum = value[outree[ns[c].l]];
return;
}else if(v <= ns[ls(c)].r)
update(ls(c), v);
else
update(rs(c), v);
ns[c].maxi = max(ns[ls(c)].maxi, ns[rs(c)].maxi);
ns[c].sum = ns[ls(c)].sum + ns[rs(c)].sum;
}
BETTER_CODE
int askmax(int c, int l, int r)
{
int t = ls(c);
if(l == ns[c].l && r == ns[c].r)
return ns[c].maxi;
else if(r <= ns[t].r)
return askmax(t, l, r);
else if(l >= ns[t|1].l)
return askmax(t|1, l, r);
else if(l <= ns[t].r && r >= ns[t|1].l)
return max(askmax(t, l, ns[t].r), askmax(t|1, ns[t|1].l, r));
}
BETTER_CODE
int asksum(int c, int l, int r)
{
int t = ls(c);
if(l == ns[c].l && r == ns[c].r)
return ns[c].sum;
else if(r <= ns[t].r)
return asksum(t, l, r);
else if(l >= ns[t|1].l)
return asksum(t|1, l, r);
else if(l <= ns[t].r && r >= ns[t|1].l)
return asksum(t, l, ns[t].r)+asksum(t|1, ns[t|1].l, r);
}
};
segtree ST;
BETTER_CODE
int querymax(int u, int v)
{
int t1 = top[u];
int t2 = top[v];
int ans = -0x2333333;
while(t1 != t2)
{
//假设t1比t2深,这里如果发现deep[t1]<deep[t2]则交换
if(deep[t1] < deep[t2])
{
swap(t1, t2);
swap(u, v);
}
ans = max(ans, ST.askmax(1, intree[t1], intree[u]));
u = parent[t1];
t1 = top[u];
}
//then t1 == t2
if(deep[u] > deep[v])
ans = max(ans, ST.askmax(1, intree[v], intree[u]));
else
ans = max(ans, ST.askmax(1, intree[u], intree[v]));
return ans;
}
BETTER_CODE
int querysum(int u, int v)
{
int t1 = top[u];
int t2 = top[v];
int ans = 0;
while(t1 != t2)
{
//假设t1比t2深,这里如果发现deep[t1]<deep[t2]则交换
if(deep[t1] < deep[t2])
{
swap(t1, t2);
swap(u, v);
}
ans += ST.asksum(1, intree[t1], intree[u]);
u = parent[t1];
t1 = top[u];
}
//then t1 == t2
if(deep[u] > deep[v])
ans += ST.asksum(1, intree[v], intree[u]);
else
ans += ST.asksum(1, intree[u], intree[v]);
return ans;
}
void change(int t, int v)
{
value[t] = v;
ST.update(1, intree[t]);
}
BETTER_CODE
int main()
{
int n;
scanf("%d", &n);
for(int i = 1; i < n; i++)
{
int a, b;
scanf("%d %d", &a, &b);
G[a].push_back(b);
G[b].push_back(a);
}
for(int i = 1; i <= n; i++)
scanf("%d", value + i);
num = 0;
dfs1(1, 0, 1);
dfs2(1, 1);
ST.build(1, 1, num);
char buf[233];
int q;
scanf("%d", &q);
while(q--)
{
int x, y;
scanf("%s", buf);
scanf("%d %d", &x, &y);
if(buf[0] == 'Q')
{
if(buf[1] == 'M')
printf("%d\n", querymax(x, y));
else
printf("%d\n", querysum(x, y));
}else if(buf[0] == 'C')
change(x, y);
}
return 0;
}