题目描述
已知一棵包含 N 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作 1: 格式: 1 x y z 表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z。
操作 2: 格式: 2 x y 表示求树从 x 到 y 结点最短路径上所有节点的值之和。
操作 3: 格式: 3 x z 表示将以 x 为根节点的子树内所有节点值都加上 z。
操作 4: 格式: 4 x 表示求以 x 为根节点的子树内所有节点值之和
输入格式
第一行包含 4 个正整数 N,M,R,P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。
接下来一行包含 N 个非负整数,分别依次表示各个节点上初始的数值。
接下来 N−1 行每行包含两个整数 x,y,表示点 x 和点 y 之间连有一条边(保证无环且连通)。
接下来 M 行每行包含若干个正整数,每行表示一个操作,格式如下:
操作 1: 1 x y z;
操作 2: 2 x y;
操作 3: 3 x z;
操作 4: 4 x。
输出格式
输出包含若干行,分别依次表示每个操作 2 或操作 4 所得的结果(对 P 取模)。
- 树链剖分详解:https://blog.csdn.net/a_forever_dream/article/details/80651308
- DFS序详解:https://blog.csdn.net/qq_36368339/article/details/79236467
树链剖分简介:
DFS序就是将树形结构转化为线性结构,用dfs遍历一遍这棵树,每棵子树内的点的dfs序都是连续的,如果先走重儿子,那么重链上的dfs序也是连续,我们就可以用线段树来维护这些结点的和。
(重儿子:一个节点的所有儿子中子树最大的儿子,由重儿子组成的链就是重链)
代码:
#include<algorithm>
#include<iostream>
#include<cstdio>
#include<cstring>
#include<map>
#include<set>
#include<string>
#include<vector>
using namespace std;
#define LL long long
#define uLL unsigned long long
#define PII pair<int,int>
#define mid ((l + r)>>1)
#define chl (root<<1)
#define chr (root<<1|1)
const int manx = 1e5 + 10;
const int INF = 2e9;
const int mod = 1e4+7;
int N,M,root,P,a[manx],cou=0,tot=0;
int head[manx];
int siz[manx],deep[manx],son[manx],fa[manx];//记录子树的大小,结点的深度、重儿子、父亲结点
int top[manx],now[manx],ctr[manx],past[manx];//重链顶端,dfs序进入、出去的时间戳,dfs序编号i的原来的编号是past[i]
struct node
{
int e,bf;
}edge[manx<<1];
void add(int x,int y)
{
edge[cou]=node{y,head[x]};
head[x]=cou++;
}
void init()
{
cou=tot=0;
for(int i=0;i<=N;i++)
head[i]=-1,son[i]=0;
}
/********************(1)找重儿子********************/
void dfs_getson(int x)
{
siz[x]=1;//初始化树的大小
for(int i=head[x];~i;i=edge[i].bf){
int y=edge[i].e;
if(y==fa[x])continue;
fa[y]=x;
deep[y]=deep[x]+1;//记录父亲结点和深度
dfs_getson(y);
siz[x]+=siz[y];
if(siz[y]>siz[son[x]])son[x]=y;//更新重儿子
}
}
/********************(2)dfs序 给结点编号********************/
void dfs_rewrite(int x,int tp)//当前点以及当前点所在重链的顶端
{
top[x]=tp;
now[x]=++tot;///进入当前的时间戳
past[tot]=x;//记录编号为tot的是哪个点,线段树建树需要用
if(son[x])dfs_rewrite(son[x],tp);//每次先走重儿子
for(int i=head[x];~i;i=edge[i].bf){
int y=edge[i].e;
if(y!=son[x]&&y!=fa[x])dfs_rewrite(y,y);
}
ctr[x]=tot;///退出当前结点的时间戳
}
/********************(3)线段树维护区间和********************/
struct nod
{
LL sum,lazy;
}tree[manx<<2];
void eval(int root)
{
tree[root].sum=(tree[chl].sum+tree[chr].sum)%P;
}
void build_tree(int root,int l,int r)
{
tree[root].lazy=0;
if(l==r){
tree[root].sum=a[past[l]];//past[l]是dfs序编号为l的点原来的的编号
return;
}
build_tree(chl,l,mid);
build_tree(chr,mid+1,r);
eval(root);
}
void push_down(int root,int l,int r)
{
if(tree[root].lazy==0)return;
tree[chl].sum+=tree[root].lazy*(mid-l+1)%P;
tree[chr].sum+=tree[root].lazy*(r-mid)%P;
tree[chl].lazy+=tree[root].lazy;
tree[chr].lazy+=tree[root].lazy;
tree[root].lazy=0;
}
void change(int root,int l,int r,int ll,int rr,int val)
{
if(l==ll&&r==rr){
tree[root].sum+=val*(r-l+1);
tree[root].lazy+=val;
return;
}
push_down(root,l,r);
if(rr<=mid)
change(chl,l,mid,ll,rr,val);
else if(ll>mid)
change(chr,mid+1,r,ll,rr,val);
else change(chl,l,mid,ll,mid,val),change(chr,mid+1,r,mid+1,rr,val);
eval(root);
}
LL getsum(int root,int l,int r,int ll,int rr)
{
if(l==ll&&r==rr)
return tree[root].sum;
push_down(root,l,r);
if(rr<=mid)
return getsum(chl,l,mid,ll,rr);
else if(ll>mid)
return getsum(chr,mid+1,r,ll,rr);
else return (getsum(chl,l,mid,ll,mid)+getsum(chr,mid+1,r,mid+1,rr))%P;
}
/********************题面的四种操作********************/
void change_seg()//将树从 x 到 y 结点最短路径上所有节点的值都加上 z
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
while(top[x]!=top[y]){//不在同一条链上
if(deep[top[x]]>deep[top[y]])swap(x,y);
change(1,1,tot,now[top[y]],now[y],z);//修改深度大的点到所在重链的顶点值
y=fa[top[y]];//跳链
}
if(deep[x]>deep[y])swap(x,y);
change(1,1,tot,now[x],now[y],z);//在同一条链上时
}
void getsum_seg()//求树从 x 到 y 结点最短路径上所有节点的值之和
{//同上
int x,y;
LL ans=0;
scanf("%d%d",&x,&y);
while(top[x]!=top[y]){
if(deep[top[x]]>deep[top[y]])swap(x,y);
(ans+=getsum(1,1,tot,now[top[y]],now[y]))%=P;
y=fa[top[y]];
}
if(deep[x]>deep[y])swap(x,y);
ans+=getsum(1,1,tot,now[x],now[y]);
printf("%lld\n",ans%P);
}
void change_sontree()//将以 x 为根节点的子树内所有节点值都加上 z
{
int x,y;
scanf("%d%d",&x,&y);
change(1,1,tot,now[x],ctr[x],y);//子树上的点是一段连续的编号:now[x]~ctr[x]
}
void getsum_sontree()//求以 x 为根节点的子树内所有节点值之和
{
int x;
scanf("%d",&x);
printf("%lld\n",getsum(1,1,tot,now[x],ctr[x])%P);
}
int main()
{
scanf("%d%d%d%d",&N,&M,&root,&P);
init();
for(int i=1;i<=N;i++)
scanf("%lld",&a[i]);
for(int i=1;i<N;i++){
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);//建边
}
fa[root]=-1,deep[root]=0,siz[0]=0;
dfs_getson(root);//找重儿子并记录每个点的深度(后面跳链的时候让深度大的点先往上
dfs_rewrite(root,root);//跑一遍dfs序给结点编号,并记录每个点所在重链的顶端
build_tree(1,1,tot);//建树
while(M--){
int op;
scanf("%d",&op);
if(op==1)change_seg();
else if(op==2)getsum_seg();
else if(op==3)change_sontree();
else getsum_sontree();
}
return 0;
}
代码2(线段树动态开点):
#include<algorithm>
#include<iostream>
#include<cstdio>
#include<cstring>
#include<map>
#include<set>
#include<string>
#include<vector>
using namespace std;
#define LL long long
#define uLL unsigned long long
#define PII pair<int,int>
#define mid ((l + r)>>1)
//#define chl (root<<1)
//#define chr (root<<1|1)
const int manx = 1e5 + 10;
const int INF = 2e9;
const int mod = 1e4+7;
int N,M,root,P,a[manx],cou=0,tot=0,len;
int head[manx];
int siz[manx],deep[manx],son[manx],fa[manx];//记录子树的大小,结点的深度、重儿子、父亲结点
int top[manx],now[manx],ctr[manx],past[manx];//重链顶端,dfs序进入、出去的时间戳,dfs序编号i的原来的编号是past[i]
struct node
{
int e,bf;
}edge[manx<<1];
void add(int x,int y)
{
edge[cou]=node{y,head[x]};
head[x]=cou++;
}
void init()
{
cou=tot=len=0;//边数,dfs序编号,线段树结点个数
for(int i=0;i<=N;i++)
head[i]=-1,son[i]=0;
}
/********************(1)找重儿子********************/
void dfs_getson(int x)
{
siz[x]=1;//初始化树的大小
for(int i=head[x];~i;i=edge[i].bf){
int y=edge[i].e;
if(y==fa[x])continue;
fa[y]=x;
deep[y]=deep[x]+1;//记录父亲结点和深度
dfs_getson(y);
siz[x]+=siz[y];
if(siz[y]>siz[son[x]])son[x]=y;//更新重儿子
}
}
/********************(2)dfs序 给结点编号********************/
void dfs_rewrite(int x,int tp)//当前点以及当前点所在重链的顶端
{
top[x]=tp;
now[x]=++tot;///进入当前的时间戳
past[tot]=x;//记录编号为tot的是哪个点,线段树建树需要用
if(son[x])dfs_rewrite(son[x],tp);//每次先走重儿子
for(int i=head[x];~i;i=edge[i].bf){
int y=edge[i].e;
if(y!=son[x]&&y!=fa[x])dfs_rewrite(y,y);
}
ctr[x]=tot;///退出当前结点的时间戳
}
/********************(3)线段树维护区间和********************/
struct nod
{
int chl,chr;
LL sum,lazy;
}tree[manx<<2];
void eval(int root)
{
int chl=tree[root].chl,chr=tree[root].chr;
tree[root].sum=(tree[chl].sum+tree[chr].sum)%P;
}
void build_tree(int l,int r)
{
++len;
tree[len].lazy=0;
if(l==r){
tree[len].sum=a[past[l]];//past[l]是dfs序编号为l的点原来的的编号
return;
}
int now=len;
tree[now].chl=len+1;
build_tree(l,mid);
tree[now].chr=len+1;
build_tree(mid+1,r);
eval(now);
}
void push_down(int root,int l,int r)
{
int chl=tree[root].chl,chr=tree[root].chr;
if(tree[root].lazy==0)return;
tree[chl].sum+=tree[root].lazy*(mid-l+1)%P;
tree[chr].sum+=tree[root].lazy*(r-mid)%P;
tree[chl].lazy+=tree[root].lazy;
tree[chr].lazy+=tree[root].lazy;
tree[root].lazy=0;
}
void change(int root,int l,int r,int ll,int rr,int val)
{
int chl=tree[root].chl,chr=tree[root].chr;
if(l==ll&&r==rr){
tree[root].sum+=val*(r-l+1);
tree[root].lazy+=val;
return;
}
push_down(root,l,r);
if(rr<=mid)
change(chl,l,mid,ll,rr,val);
else if(ll>mid)
change(chr,mid+1,r,ll,rr,val);
else change(chl,l,mid,ll,mid,val),change(chr,mid+1,r,mid+1,rr,val);
eval(root);
}
LL getsum(int root,int l,int r,int ll,int rr)
{
int chl=tree[root].chl,chr=tree[root].chr;
if(l==ll&&r==rr)
return tree[root].sum;
push_down(root,l,r);
if(rr<=mid)
return getsum(chl,l,mid,ll,rr);
else if(ll>mid)
return getsum(chr,mid+1,r,ll,rr);
else return (getsum(chl,l,mid,ll,mid)+getsum(chr,mid+1,r,mid+1,rr))%P;
}
/********************题面的四种操作********************/
void change_seg()//将树从 x 到 y 结点最短路径上所有节点的值都加上 z
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
while(top[x]!=top[y]){//不在同一条链上
if(deep[top[x]]>deep[top[y]])swap(x,y);
change(1,1,tot,now[top[y]],now[y],z);//修改深度大的点到所在重链的顶点值
y=fa[top[y]];//跳链
}
if(deep[x]>deep[y])swap(x,y);
change(1,1,tot,now[x],now[y],z);//在同一条链上时
}
void getsum_seg()//求树从 x 到 y 结点最短路径上所有节点的值之和
{//同上
int x,y;
LL ans=0;
scanf("%d%d",&x,&y);
while(top[x]!=top[y]){
if(deep[top[x]]>deep[top[y]])swap(x,y);
(ans+=getsum(1,1,tot,now[top[y]],now[y]))%=P;
y=fa[top[y]];
}
if(deep[x]>deep[y])swap(x,y);
ans+=getsum(1,1,tot,now[x],now[y]);
printf("%lld\n",ans%P);
}
void change_sontree()//将以 x 为根节点的子树内所有节点值都加上 z
{
int x,y;
scanf("%d%d",&x,&y);
change(1,1,tot,now[x],ctr[x],y);//子树上的点是一段连续的编号:now[x]~ctr[x]
}
void getsum_sontree()//求以 x 为根节点的子树内所有节点值之和
{
int x;
scanf("%d",&x);
printf("%lld\n",getsum(1,1,tot,now[x],ctr[x])%P);
}
int main()
{
scanf("%d%d%d%d",&N,&M,&root,&P);
init();
for(int i=1;i<=N;i++)
scanf("%lld",&a[i]);
for(int i=1;i<N;i++){
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);//建边
}
fa[root]=-1,deep[root]=0,siz[0]=0;
dfs_getson(root);//找重儿子并记录每个点的深度(后面跳链的时候让深度大的点先往上
dfs_rewrite(root,root);//跑一遍dfs序给结点编号,并记录每个点所在重链的顶端
build_tree(1,tot);//建树
while(M--){
int op;
scanf("%d",&op);
if(op==1)change_seg();
else if(op==2)getsum_seg();
else if(op==3)change_sontree();
else getsum_sontree();
}
return 0;
}