【模板】树链剖分
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
树链刨分
个人理解:
树链刨分 把一棵树拆成若干个不相交的链,然后用一些数据结构去维护这些链
dfs先走重链在走轻链
先附上关键代码
int depth[maxn],father[maxn],siz[maxn],son[maxn]; //depth为深度 father为父亲 son为重儿子
int in[maxn],out[maxn],num[maxn],top[maxn];//in为进入的时间戳 out为出去的时间戳 num为线性结构上对应的节点 top为当前链顶节点
int tim;
void dfs1(int u,int fa) 非叶子节点的重儿子编号son[]
{
father[u]=fa;//标记每个点的深度
siz[u]=1;//标记每个节点子树的大小
depth[u]=depth[fa]+1;//标记每个点的深度
int maxsize=-1;//记录中儿子的儿子数
for(int i=0; i<edge[u].size(); i++)
{
int v=edge[u][i];
if(v==fa)continue;//若为父亲continue
dfs1(v,u);//dfs其儿子
siz[u]+=siz[v];//把他的儿子数加到它身上
if(siz[v]>maxsize)
{
son[u]=v;
maxsize=siz[v];//标记每个重儿子的编号
}
}
}
void dfs2(int u,int fa) //这个fa指的是重链上的头指的轻儿子 即当前链最顶端的节点
{
in[u]=++tim;//标记每个点的新标号
num[tim]=u;//线性结构上对应的节点
top[u]=fa;//把这个点指向所在链的顶端
if(!son[u])//没有儿子 说明是叶子节点 记录返回的时间戳 return
{
out[u]=tim;
return ;
}
dfs2(son[u],fa);//先处理重儿子,在处理轻儿子的顺序递归
for(int i=0; i<edge[u].size(); i++)
{
int v=edge[u][i];
if(v==father[u]||v==son[u])continue;//如果是父亲或是重儿子continue
dfs2(v,v);//对于每个轻儿子都有一条从自己开始的链
}
out[u]=tim;
}
void uchain(int x,int y,int z) //修改链
{
while(top[x]!=top[y]) //当两个点不在同一条链上
{
if(depth[top[x]]<depth[top[y]])//把x点改为所在链顶端的深度更深的那个点
swap(x,y);
update(1,in[top[x]],in[x],z);//修改x点到x所在链顶端 这一段区间的点权和
x=father[top[x]];//把x跳到x所在链顶端的那个点的上面一个点
}
//直到两个点处在一条链上
if(depth[x]>depth[y]) //x在上面 x的深度小
swap(x,y);
update(1,in[x],in[y],z);//这时修改此时两个点的区间即可
}
int qchain(int x,int y) //查询链 同理
{
int ans=0;
while(top[x]!=top[y])
{
if(depth[top[x]]<depth[top[y]])
swap(x,y);
ans+=qurey(1,in[top[x]],in[x]);
x=father[top[x]];
}
if(depth[x]>depth[y])
swap(x,y);
ans+=qurey(1,in[x],in[y]);
return ans;
}
完整代码
/*
将树从x到y结点最短路径上所有节点的值都加上z
求树从x到y结点最短路径上所有节点的值之和
将以x为根节点的子树内所有节点值都加上z
求以x为根节点的子树内所有节点值之和
*/
//树链刨分 把一棵树拆成若干个不相交的链,然后用一些数据结构去维护这些链
//dfs先走重链在走轻链
#pragma GCC optimize(3,"Ofast","inline") //G++
#pragma comment(linker, "/STACK:102400000,102400000")
#pragma GCC optimize(2)
#include<bits/stdc++.h>
#include <ext/hash_map> //hashmap
#include <functional>
#define TEST freopen("C:\\Users\\hp\\Desktop\\ACM\\in.txt","r",stdin);
#define mem(a,x) memset(a,x,sizeof(a))
#define ios ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace __gnu_cxx;
using namespace std;
typedef long long ll;
typedef unsigned long long ull; // %llu
const double PI = acos(-1.0);
const double eps = 1e-6;
//const int mod=1e9+7;
const int INF = -1u>>1;
const int maxn = 1e6+10;
struct node
{
int l,r,sum,lazy;
} t[maxn*4];
int a[maxn];
int depth[maxn],father[maxn],siz[maxn],son[maxn]; //depth为深度 father为父亲 son为重儿子
int in[maxn],out[maxn],num[maxn],top[maxn];//in为进入的时间戳 out为出去的时间戳 num为线性结构上对应的节点 top为当前链顶节点
int n,q,root,mod,tim; //
vector<int>edge[maxn];
void add(int x,int y)
{
edge[x].push_back(y);
edge[y].push_back(x);
}
void dfs1(int u,int fa) //标记每个点的深度depth[] 标记每个点的父亲father[] 标记每个非叶子节点的子树大小(含它自己)siz[] 标记每个非叶子节点的重儿子编号son[]
{
father[u]=fa;//标记每个点的深度
siz[u]=1;//标记每个节点子树的大小
depth[u]=depth[fa]+1;//标记每个点的深度
int maxsize=-1;//记录中儿子的儿子数
for(int i=0; i<edge[u].size(); i++)
{
int v=edge[u][i];
if(v==fa)continue;//若为父亲continue
dfs1(v,u);//dfs其儿子
siz[u]+=siz[v];//把他的儿子数加到它身上
if(siz[v]>maxsize)
{
son[u]=v;
maxsize=siz[v];//标记每个重儿子的编号
}
}
}
void dfs2(int u,int fa) //这个fa指的是重链上的头指的轻儿子 即当前链最顶端的节点
{
in[u]=++tim;//标记每个点的新标号
num[tim]=u;//线性结构上对应的节点
top[u]=fa;//把这个点指向所在链的顶端
if(!son[u])//没有儿子 说明是叶子节点 记录返回的时间戳 return
{
out[u]=tim;
return ;
}
dfs2(son[u],fa);//先处理重儿子,在处理轻儿子的顺序递归
for(int i=0; i<edge[u].size(); i++)
{
int v=edge[u][i];
if(v==father[u]||v==son[u])continue;//如果是父亲或是重儿子continue
dfs2(v,v);//对于每个轻儿子都有一条从自己开始的链
}
out[u]=tim;
}
//-------------------------------------- 以下为线段树
void pushup(int o)
{
t[o].sum=t[o<<1].sum+t[o<<1|1].sum;
}
void pushdown(int o)
{
if(t[o].lazy)
{
t[o].lazy%=mod;
t[o<<1].lazy+=t[o].lazy;
t[o<<1].lazy%=mod;
t[o<<1|1].lazy+=t[o].lazy;
t[o<<1|1].lazy%=mod;
t[o<<1].sum+=(t[o<<1].r-t[o<<1].l+1)*t[o].lazy;
t[o<<1].sum%=mod;
t[o<<1|1].sum+=(t[o<<1|1].r-t[o<<1|1].l+1)*t[o].lazy;
t[o<<1|1].sum%=mod;
t[o].lazy=0;
}
}
void build(int l,int r,int o)
{
t[o].l=l;
t[o].r=r;
t[o].lazy=0;
t[o].sum=0;
if(l==r)
{
t[o].sum=a[num[l]];
return ;
}
int mid=(l+r)>>1;
build(l,mid,o<<1);
build(mid+1,r,o<<1|1);
pushup(o);
}
void update(int o,int l,int r,int z)
{
if(t[o].l==l&&t[o].r==r)
{
t[o].lazy+=z;
t[o].lazy%=mod;
t[o].sum+=(t[o].r-t[o].l+1)*z;
t[o].sum%=mod;
return ;
}
pushdown(o);
int mid=(t[o].l+t[o].r)>>1;
if(r<=mid)
update(o<<1,l,r,z);
else if(l>mid)
update(o<<1|1,l,r,z);
else
{
update(o<<1,l,mid,z);
update(o<<1|1,mid+1,r,z);
}
pushup(o);
}
int qurey(int o,int l,int r)
{
if(t[o].l==l&&t[o].r==r)
{
return t[o].sum%mod;
}
pushdown(o);
int mid=(t[o].l+t[o].r)>>1;
if(r<=mid)
return qurey(o<<1,l,r);
else if(l>mid)
return qurey(o<<1|1,l,r);
else
{
return (qurey(o<<1,l,mid)+qurey(o<<1|1,mid+1,r))%mod;
}
}
//---------------------------------以上为线段树
void uchain(int x,int y,int z) //修改链
{
while(top[x]!=top[y]) //当两个点不在同一条链上
{
if(depth[top[x]]<depth[top[y]])//把x点改为所在链顶端的深度更深的那个点
swap(x,y);
update(1,in[top[x]],in[x],z);//修改x点到x所在链顶端 这一段区间的点权和
x=father[top[x]];//把x跳到x所在链顶端的那个点的上面一个点
}
//直到两个点处在一条链上
if(depth[x]>depth[y]) //x在上面 x的深度小
swap(x,y);
update(1,in[x],in[y],z);//这时修改此时两个点的区间即可
}
int qchain(int x,int y) //查询链 同理
{
int ans=0;
while(top[x]!=top[y])
{
if(depth[top[x]]<depth[top[y]])
swap(x,y);
ans+=qurey(1,in[top[x]],in[x]);
x=father[top[x]];
}
if(depth[x]>depth[y])
swap(x,y);
ans+=qurey(1,in[x],in[y]);
return ans;
}
int main()
{
// TEST
ios;
cin>>n>>q>>root>>mod;
for(int i=1; i<=n; i++)
{
cin>>a[i];
a[i]%=mod;
}
for(int i=1; i<n; i++)
{
int x,y;
cin>>x>>y;
add(x,y);
}
dfs1(root,0);
dfs2(root,root);
build(1,n,1);
while(q--)
{
int op,x,y,z;
cin>>op;
if(op==1)
{
cin>>x>>y>>z;
z%=mod;
uchain(x,y,z);
}
else if(op==2)
{
cin>>x>>y;
cout<<qchain(x,y)%mod<<"\n";
}
else if(op==3)
{
cin>>x>>z;
z%=mod;
update(1,in[x],out[x],z); //in out之间的区间就是x的子树
}
else
{
cin>>x;
cout<<qurey(1,in[x],out[x])%mod<<"\n";
}
}
}