写在前面
首先,在学树链剖分之前最好先把 LCA、树形DP、DFS序 这三个知识点学了
emm还有必备的 链式前向星、线段树 也要先学了。
如果这三个知识点没掌握好的话,树链剖分难以理解也是当然的。
树链剖分
树链剖分 就是对一棵树分成几条链,把树形变为线性,减少处理难度
需要处理的问题:
- 将树从x到y结点最短路径上所有节点的值都加上z
- 求树从x到y结点最短路径上所有节点的值之和
- 将以x为根节点的子树内所有节点值都加上z
- 求以x为根节点的子树内所有节点值之和
目录:
-
概念
-
dfs1()
-
dfs2()
-
处理问题
-
对剖就的树建立线段树
- 重儿子:对于每一个非叶子节点,它的儿子中 以那个儿子为根的子树节点数最大的儿子 为该节点的重儿子 (Ps: 感谢@shzr大佬指出我此句话的表达不严谨qwq, 已修改)
- 轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子
- 叶子节点没有重儿子也没有轻儿子(因为它没有儿子。。)
- 重边:一个父亲连接他的重儿子的边称为重边 //原写法:连接任意两个重儿子的边叫做重边
- 轻边:剩下的即为轻边
- 重链:相邻重边连起来的 连接一条重儿子 的链叫重链
- 对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链
- 每一条重链以轻儿子为起点
dfs1()
这个dfs要处理几件事情:
- 标记每个点的深度dep[]
- 标记每个点的父亲fa[]
- 标记每个非叶子节点的子树大小(含它自己)
- 标记每个非叶子节点的重儿子编号son[]
void dfs1(int u, int f, int deep)//当前节点、父节点、层次深度
{
dep[u] = deep;//标记每个点的深度
sz[u] = 1;//标记每个非叶子节点的子树大小
fa[u] = f;//标记每个点的父亲
int maxSon = -1;//记录重儿子的儿子数
for (int i = head[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == f)//若为父亲则continue
continue;
dfs1(v, u, deep + 1);//dfs其儿子
sz[u] += sz[v];//子节点的size已被处理,用它来更新父节点的size
if (sz[v] > maxSon) //选取size最大的作为重儿子
{
son[u] = v;
maxSon = sz[v];
}
}
}
dfs2()
这个dfs2也要预处理几件事情
- 标记每个点的新编号
- 赋值每个点的初始值到新编号上
- 处理每个点所在链的顶端
- 处理每条链
顺序:先处理重儿子再处理轻儿子,理由后面说
void dfs2(int u, int ttop)//u当前节点,ttpf当前链的最顶端的节点
{
id[u] = ++cn;//标记每个点的新编号
w[cn] = bufW[u];//把每个点的初始值赋到新编号上来
top[u] = ttop;//这个点所在链的顶端
if (!son[u])//如果没有儿子则返回
return;
dfs2(son[u], ttop);//按先处理重儿子,再处理轻儿子的顺序递归处理
for (int i = head[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == fa[u] || v == son[u])
continue;
dfs2(v, v);//这个点位于轻链顶端,那么它的top必然为它本身
}
}
Attention 重要的来了!!!
前面说到dfs2的顺序是先处理重儿子再处理轻儿子
我们来模拟一下:
- 因为顺序是先重再轻,所以每一条重链的新编号是连续的
- 因为是dfs,所以每一个子树的新编号也是连续的
现在回顾一下我们要处理的问题
- 处理任意两点间路径上的点权和
- 处理一点及其子树的点权和
- 修改任意两点间路径上的点权
- 修改一点及其子树的点权
1、当我们要处理任意两点间路径时:
设所在链顶端的深度更深的那个点为x点
- ans加上x点到x所在链顶端 这一段区间的点权和
- 把x跳到x所在链顶端的那个点的上面一个点
不停执行这两个步骤,直到两个点处于一条链上,这时再加上此时两个点的区间和即可
这时我们注意到,我们所要处理的所有区间均为连续编号(新编号),于是想到线段树,用线段树处理连续编号区间和
每次查询时间复杂度为O(log2n)O(log2n)
int queryRange(int p1,int p2){
int ans = 0;
while (top[p1]!=top[p2])//当两个点不在同一条链上
{
if(dep[top[p1]] < dep[top[p2]])//把x点改为所在链顶端的深度更深的那个点
swap(p1,p2);
ans += query(1,id[top[p1]],id[p1]);//ans加上x点到x所在链顶端 这一段区间的点权和
ans %= mod;
p1 = fa[top[p1]];//把x跳到x所在链顶端的那个点的上面一个点
}
if(dep[p1] > dep[p2])//把x点深度更深的那个点
swap(p1,p2);
return (ans + query(1,id[p1],id[p2]))%mod;//这时再加上此时两个点的区间和即可
}
2、处理一点及其子树的点权和:
想到记录了每个非叶子节点的子树大小(含它自己),并且每个子树的新编号都是连续的
于是直接线段树区间查询即可
时间复杂度为O(logn)O(logn)
int querySon(int p){
return query(1,id[p],id[p]+sz[p]-1)%mod; //子树区间右端点为id[p]+sz[p]-1
}
当然,区间修改就和区间查询一样的啦~~
void updateRange(int p1,int p2,int val){
while (top[p1] != top[p2])
{
if(dep[top[p1]] < dep[top[p2]])
swap(p1,p2);
update(1,id[top[p1]],id[p1],val);
p1 = fa[top[p1]];
}
if(dep[p1] > dep[p2])
swap(p1,p2);
update(1,id[p1],id[p2],val);
}
void updateSon(int p,int val){
update(1,id[p],id[p]+sz[p]-1,val);
}
既然前面说到要用线段树,那么按题意建树就可以啦!
不过,建树这一步当然是在处理问题之前哦~
AC代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<vector>
#include<algorithm>
using namespace std;
#define inf 0x3f3f3f3f
const int maxn=1e6+3;
struct edge
{
int next, to;
} e[maxn];
int cnt,mod;
int head[maxn];
void addEdge(int u, int v)
{
e[++cnt] = {head[u], v};
head[u] = cnt;
}
struct node
{
int l, r, flag, w;
int dis() { return r - l + 1; }
int mid() { return (r + l) / 2; }
} a[maxn];
int rt;
int cn;
int w[maxn]; //新编号
int bufW[maxn];//表示各个节点上初始的数值
int id[maxn]; //表示结点x在线段树中的编号
int fa[maxn]; //表示节点x在树上的父亲
int top[maxn]; //表示节点s所在重链的顶部节点(深度最小)
int son[maxn]; //表示节点x的重儿子
int dep[maxn]; //表示节点x在树上的深度
int sz[maxn]; //表示节点x的子树的节点的个数
/*dfs1所要做的事情
标记每个点的深度dep[]
标记每个点的父亲fa[]
标记每个非叶子节点的子树大小(含它自己)
标记每个非叶子节点的重儿子编号son[]*/
void dfs1(int u, int f, int deep)//当前节点、父节点、层次深度
{
dep[u] = deep;//标记每个点的深度
sz[u] = 1;//标记每个非叶子节点的子树大小
fa[u] = f;//标记每个点的父亲
int maxSon = -1;//记录重儿子的儿子数
for (int i = head[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == f)//若为父亲则continue
continue;
dfs1(v, u, deep + 1);//dfs其儿子
sz[u] += sz[v];//子节点的size已被处理,用它来更新父节点的size
if (sz[v] > maxSon) //选取size最大的作为重儿子
{
son[u] = v;
maxSon = sz[v];
}
}
}
/*dfs2所要做的事情
标记每个点的新编号
赋值每个点的初始值到新编号上
处理每个点所在链的顶端
处理每条链*/
void dfs2(int u, int ttop)//u当前节点,ttpf当前链的最顶端的节点
{
id[u] = ++cn;//标记每个点的新编号
w[cn] = bufW[u];//把每个点的初始值赋到新编号上来
top[u] = ttop;//这个点所在链的顶端
if (!son[u])//如果没有儿子则返回
return;
dfs2(son[u], ttop);//按先处理重儿子,再处理轻儿子的顺序递归处理
for (int i = head[u]; i; i = e[i].next)
{
int v = e[i].to;
if (v == fa[u] || v == son[u])
continue;
dfs2(v, v);//这个点位于轻链顶端,那么它的top必然为它本身
}
}
void build(int k, int l, int r)
{
a[k] = {l, r, 0, 0};
if (a[k].l == a[k].r)
{
a[k].w = w[r];
return;
}
int mid=(l+r)/2;
build(k*2, l, mid);
build(k*2+1,mid+1, r);
a[k].w = a[k*2].w + a[k *2+1].w;
a[k].w %= mod;
}
void down(int k)
{
a[k*2].w += a[k*2].dis() * a[k].flag;
a[k*2+1].w += a[k*2+1].dis() * a[k].flag;
a[k*2].flag += a[k].flag;
a[k*2+1].flag += a[k].flag;
a[k].flag = 0;
}
void update(int k, int l, int r, int val)
{
if (a[k].l >= l && a[k].r <= r)
{
a[k].w += val * a[k].dis();
a[k].flag += val;
return;
}
if (a[k].flag)
down(k);
if (a[k].mid() >= l)
update(k*2, l, r, val);
if (a[k].mid() < r)
update(k*2+1, l, r, val);
a[k].w = a[k*2].w + a[k*2+1].w;
a[k].w %= mod;
}
int query(int k, int l, int r)
{
if (a[k].l >= l && a[k].r <= r)
return a[k].w;
int res = 0;
if (a[k].flag)
down(k);
if (a[k].mid() >= l)
res += query(k*2, l, r);
if (a[k].mid() < r)
res += query(k*2+1, l, r);
return res % mod;
}
void updateRange(int p1,int p2,int val){
while (top[p1] != top[p2])
{
if(dep[top[p1]] < dep[top[p2]])
swap(p1,p2);
update(1,id[top[p1]],id[p1],val);
p1 = fa[top[p1]];
}
if(dep[p1] > dep[p2])
swap(p1,p2);
update(1,id[p1],id[p2],val);
}
int queryRange(int p1,int p2){
int ans = 0;
while (top[p1]!=top[p2])//当两个点不在同一条链上
{
if(dep[top[p1]] < dep[top[p2]])//把x点改为所在链顶端的深度更深的那个点
swap(p1,p2);
ans += query(1,id[top[p1]],id[p1]);//ans加上x点到x所在链顶端 这一段区间的点权和
ans %= mod;
p1 = fa[top[p1]];//把x跳到x所在链顶端的那个点的上面一个点
}
if(dep[p1] > dep[p2])//把x点深度更深的那个点
swap(p1,p2);
return (ans + query(1,id[p1],id[p2]))%mod;//这时再加上此时两个点的区间和即可
}
void updateSon(int p,int val){
update(1,id[p],id[p]+sz[p]-1,val);
}
int querySon(int p){
return query(1,id[p],id[p]+sz[p]-1)%mod; //子树区间右端点为id[p]+sz[p]-1
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n,k,m,rt;
cin >> n >> m >> rt >> mod;
//分别表示树的结点个数、操作个数、根节点序号和取模数
for(int i=1;i<=n;i++)
cin >> bufW[i];
int u,v;
for(int i=1;i<n;i++){
cin >> u >> v;
addEdge(u,v);
addEdge(v,u);
}
dfs1(rt,0,1);
dfs2(rt,rt);
build(1,1,n);
int l, r, w;
while (m--)
{
cin >> k;
if (k == 1)
{
cin >> l >> r >> w;
//从l到r结点最短路径上所有节点的值都加上w
updateRange(l, r, w);
}
else if (k == 2)
{
cin >> l >> r;
// 表示求树从l到r结点最短路径上所有节点的值之和
cout<<queryRange(l, r)<<endl;
}
else if (k == 3)
{
cin >> l >> w;
//表示将以l为根节点的子树内所有节点值都加上w
updateSon(l, w);
}
else
{
cin >> l;
//表示求以l为根节点的子树内所有节点值之和
cout<<querySon(l)<<endl;
}
}
return 0;
}