题目
本人太菜,只会树链剖分的做法。
现在列出学树链剖分的预备知识
dfs序 线段树
现在先简单介绍一下树链剖分的主要操作。
1求树上两点的最短距离。
2将树上某两点最短路径上的所有点的点权都加x。
3将以某节点为根的子树上的点都加x。
4求以某节点为根节点的子树内所有节点权值之和。
介绍完树链剖分的用途。
接下来介绍一下树链剖分的七个常量数组。
数组 | 含义 |
---|---|
dfn[] | 节点的dfs序 (即节点的新编号) |
pre[] | 保存节点的老编号 |
size[] | 以x为根的子树大小 |
d[] | 节点x的深度 |
son[] | x的子节点中size[]最大的节点 |
fa[] | 保存x节点的父节点 |
top[] | 保存x节点的链头(稍后会详细介绍) |
了解了这几个常量数组后接下来再来科普几个概念
概念 | 含义 |
---|---|
重儿子 | 父亲节点的所有儿子中子树结点数目最多的结点 |
轻儿子 | 父亲节点中除了重儿子以外的所有儿子节点 |
重边 | 父亲结点和重儿子连成的边 |
轻边 | 父亲节点和轻儿子连成的边 |
重链 | 由多条重边连接而成的路径 |
轻链 | 由多条轻边连接而成的路径 |
这些概念最好画图来帮助理解
接下来介绍几个函数
计算四个量 fa d size son
void dfs1(int x,int f){
fa[x]=f;
d[x]=d[f]+1;
size[x]=1;
for(int i=head[x];i ;i=next[i]){
int v=ver[i];
if(v==f) continue;
dfs1(v,x);
size[x]+=size[v];
if(son[x]==-1||size[son[x]]<size[v])
son[x]=v;
}
}
计算后三个量 top dfn pre
void dfs2(int x,int tp){
top[x]=tp;
dfn[x]=++cnt;
pre[cnt]=x;
if(son[x]==-1) return ;
if(son[x])
dfs2(son[x],tp);//保证重边都在一条链上
for(int i=head[x]; i ;i=next[i]){
int v=ver[i];
if(v==fa[x]||v==son[x]) continue;
dfs2(v,v);
}
}
将值上传的操作
void up(int p){
t[p].data=t[p<<1].data+t[(p<<1)|1].data;
}
线段树建树
void build(int p,int l,int r){
t[p].l=l;
t[p].r=r;
if(l==r){
t[p].data=val[pre[l]];
return ;
}
int mid=l+r>>1;
build(p<<1,l,mid);
build((p<<1)|1,mid+1,r);
up(p);
}
线段树查询
long long sum(int p,int l,int r){
if(l<=t[p].l&&r>=t[p].r){
return t[p].data;
}
push_down(p);
long long ans=0;
int mid=t[p].l+t[p].r>>1;
if(l<=mid)
ans+=sum(p<<1,l,r);
if(r>mid)
ans+=sum((p<<1)|1,l,r);
return ans;
}
线段树修改
void add(int p,int l,int r,long long c){
if(l<=t[p].l&&r>=t[p].r){
laz[p]+=c;
t[p].data+=c*(t[p].r-t[p].l+1);
return ;
}
push_down(p);
int mid=t[p].l+t[p].r>>1;
if(l<=mid)
add(p<<1,l,r,c);
if(r>mid)
add((p<<1)|1,l,r,c);
up(p);
}
求路径的长度
long long getsum(int x,int y){
long long ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ans+=sum(1,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans+=sum(1,dfn[x],dfn[y]);
return ans;
}
有了这几个函数就可以解决这个题了。
1求树上两点的最短距离。
getsum(x,y);
2将树上某两点最短路径上的所有点的点权都加x。
3将以某节点为根的子树上的点都加x。
add(1,dfn[u],dfn[u]+size[x]-1,x);
4求以某节点为根节点的子树内所有节点权值之和。
sum(1,dfn[x],dfn[x]+size[x]-1);
以下是全代码
#include<bits/stdc++.h>
using namespace std;
const int ll=8e5;
int n,m,val[ll];
int dfn[ll],pre[ll],fa[ll],d[ll],son[ll],size[ll],top[ll],cnt;
int head[ll],next[ll],ver[ll],tot;
long long laz[ll<<2];
struct node{
int l,r;
long long data;
}t[ll<<2];
void add(int u,int v){
tot++;
ver[tot]=v;
next[tot]=head[u];
head[u]=tot;
}
void up(int p){
t[p].data=t[p<<1].data+t[(p<<1)|1].data;
}
void push_down(int p){
if(laz[p]!=0){
int l=p<<1,r=(p<<1)|1;
laz[l]+=laz[p];
laz[r]+=laz[p];
t[l].data+=laz[p]*(t[l].r-t[l].l+1);
t[r].data+=laz[p]*(t[r].r-t[r].l+1);
laz[p]=0;
}
}
void build(int p,int l,int r){
t[p].l=l;
t[p].r=r;
if(l==r){
t[p].data=val[pre[l]];
return ;
}
int mid=l+r>>1;
build(p<<1,l,mid);
build((p<<1)|1,mid+1,r);
up(p);
}
//计算出前四个量 fa , d , size , son
void dfs1(int x,int f){
fa[x]=f;
d[x]=d[f]+1;
size[x]=1;
for(int i=head[x];i ;i=next[i]){
int v=ver[i];
if(v==f) continue;
dfs1(v,x);
size[x]+=size[v];
if(son[x]==-1||size[son[x]]<size[v])
son[x]=v;
}
}
//计算后三个量
void dfs2(int x,int tp){
top[x]=tp;
dfn[x]=++cnt;
pre[cnt]=x;
if(son[x]==-1) return ;
if(son[x])
dfs2(son[x],tp);
for(int i=head[x]; i ;i=next[i]){
int v=ver[i];
if(v==fa[x]||v==son[x]) continue;
dfs2(v,v);
}
}
long long sum(int p,int l,int r){
if(l<=t[p].l&&r>=t[p].r){
return t[p].data;
}
push_down(p);
long long ans=0;
int mid=t[p].l+t[p].r>>1;
if(l<=mid)
ans+=sum(p<<1,l,r);
if(r>mid)
ans+=sum((p<<1)|1,l,r);
return ans;
}
void add(int p,int l,int r,long long c){
if(l<=t[p].l&&r>=t[p].r){
laz[p]+=c;
t[p].data+=c*(t[p].r-t[p].l+1);
return ;
}
push_down(p);
int mid=t[p].l+t[p].r>>1;
if(l<=mid)
add(p<<1,l,r,c);
if(r>mid)
add((p<<1)|1,l,r,c);
up(p);
}
long long getsum(int x,int y){
long long ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ans+=sum(1,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans+=sum(1,dfn[x],dfn[y]);
return ans;
}
int main(){
memset(son,-1,sizeof(son));
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&val[i]);
for(int i=1;i<=n-1;i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
for(int i=1;i<=m;i++){
int opt,x;
long long a;
scanf("%d%d",&opt,&x);
if(opt==1){
scanf("%lld",&a);
add(1,dfn[x],dfn[x],a);
}
if(opt==2){
scanf("%lld",&a);
add(1,dfn[x],dfn[x]+size[x]-1,a);
}
if(opt==3){
printf("%lld\n",getsum(1,x));
}
}
return 0;
}