dfs序与树上差分
引入
dfs序就是用递归的方法遍历一棵树的顺序。
这是一个括号序列,有助于解决树上的很多问题。
例1 DFS 序 1
题目描述
这是一道模板题。
给一棵有根树,这棵树由编号为
1...
N
1...N
1...N 的
N
N
N 个结点组成。根结点的编号为
R
R
R。每个结点都有一个权值,结点
i
i
i 的权值为
v
i
v_i
vi 。
接下来有
M
M
M 组操作,操作分为两类:
1 a x
,表示将结点 a a a 的权值增加 x x x;2 a
,表示求结点 a a a 的子树上所有结点的权值之和。
输入格式
第一行有三个整数
N
,
M
N,M
N,M 和
R
R
R。
第二行有
N
N
N 个整数,第
i
i
i 个整数表示
v
i
v_i
vi 。
在接下来的
N
−
1
N-1
N−1 行中,每行两个整数,表示一条边。
在接下来的
M
M
M 行中,每行一组操作。
输出格式
对于每组 2 a
操作,输出一个整数,表示「以结点
a
a
a 为根的子树」上所有结点的权值之和。
样例输入 1
10 14 9
12 -6 -4 -3 12 8 9 6 6 2
8 2
2 10
8 6
2 7
7 1
6 3
10 9
2 4
10 5
1 4 -1
2 2
1 7 -1
2 10
1 10 5
2 1
1 7 -5
2 5
1 1 8
2 7
1 8 8
2 2
1 5 5
2 6
样例输出 1
21
34
12
12
23
31
4
数据范围与提示
1 ⩽ N , M ⩽ 1 0 6 , 1 ⩽ R ⩽ N , − 1 0 6 ⩽ v i , x ⩽ 1 0 6 . 1⩽N,M⩽10^6, 1⩽R⩽N, −10^6⩽v_i,x⩽10^6. 1⩽N,M⩽106,1⩽R⩽N,−106⩽vi,x⩽106.
解析
操作:
-
1.给某个点增加x
-
我们只要给对应的序号的数加上x就可以了。
-
2.询问子树之和
-
一颗子树的dfs序是连续的,对应了一个连续的 区间,所以我们查询区间和。
-
维护一个树状数组即可。
代码
#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define R register int
#define re(i,a,b) for(R i=a; i<=b; i++)
#define ms(i,a) memset(a,i,sizeof(a))
#define MAX(a,b) (((a)>(b)) ? (a):(b))
#define MIN(a,b) (((a)<(b)) ? (a):(b))
using namespace std;
typedef long long LL;
int const N=1000005;
int n,m,r,cnt,sum;
int a[N],tin[N],tout[N],h[N];
LL s[N];
struct Edge{
int to,nt;
} e[N<<1];
inline void add(int a,int b) {
e[++cnt].to=b,e[cnt].nt=h[a],h[a]=cnt;
e[++cnt].to=a,e[cnt].nt=h[b],h[b]=cnt;
}
inline void Add(int x,LL v) {
while(x<=n) s[x]+=v,x+=x&-x;
}
inline LL getsum(int x) {
LL ret=0;
while(x) ret+=s[x],x-=x&-x;
return ret;
}
void dfs(int x,int fa) {
tin[x]=++sum;
for(int i=h[x]; i; i=e[i].nt) {
int v=e[i].to;
if(v==fa) continue;
dfs(v,x);
}
tout[x]=sum;
}
int main() {
scanf("%d%d%d",&n,&m,&r);
for(int i=1; i<=n; i++) scanf("%d",&a[i]);
for(int i=1; i<=n-1; i++) {
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
}
dfs(r,r);
for(int i=1; i<=n; i++) Add(tin[i],a[i]);
while(m--) {
int k,x,y;
scanf("%d",&k);
if(k==1) {
scanf("%d%d",&x,&y);
Add(tin[x],y);
} else {
scanf("%d",&x);
int l=tin[x];
int r=tout[x];
printf("%lld\n",getsum(r)-getsum(l-1));
}
}
return 0;
}
例2 DFS 序 2
题目描述
这是一道模板题。
给一棵有根树,这棵树由编号为
1...
N
1...N
1...N 的
N
N
N 个结点组成。根结点的编号为
R
R
R 。每个结点都有一个权值,结点
i
i
i 的权值为
v
i
v_i
vi 。
接下来有
M
M
M 组操作,操作分为两类:
1 a x
,表示将结点 a a a 的子树上所有结点的权值增加 x x x;2 a
,表示求结点 a a a 的子树上所有结点的权值之和。
输入格式
第一行有三个整数
N
,
M
N,M
N,M 和
R
R
R。
第二行有
N
N
N 个整数,第
i
i
i 个整数表示
v
i
v_i
vi。
在接下来的
N
−
1
N-1
N−1 行中,每行两个整数,表示一条边。
在接下来的
M
M
M 行中,每行一组操作。
输出格式
对于每组 2 a
操作,输出一个整数,表示「以结点
a
a
a 为根的子树」上所有结点的权值之和。
样例输入
10 14 9
12 -6 -4 -3 12 8 9 6 6 2
8 2
2 10
8 6
2 7
7 1
6 3
10 9
2 4
10 5
1 4 -1
2 2
1 7 -1
2 10
1 10 5
2 1
1 7 -5
2 5
1 1 8
2 7
1 8 8
2 2
1 5 5
2 6
样例输出
21
33
16
17
27
76
30
数据范围与提示
1 ⩽ N , M ⩽ 1 0 6 , 1 ⩽ R ⩽ N , − 1 0 6 ⩽ v i , x ⩽ 1 0 6 . 1⩽N,M⩽10^6, 1⩽R⩽N, −10^6⩽v_i,x⩽10^6. 1⩽N,M⩽106,1⩽R⩽N,−106⩽vi,x⩽106.
解析
操作:
- 1.子树加
- 子树对应了一个连续的区间,那么就是一个区间修改
- 2.子树查询
- 区间查询
- 维护一个数组数组
代码
#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define R register int
#define re(i,a,b) for(R i=a; i<=b; i++)
#define ms(i,a) memset(a,i,sizeof(a))
#define MAX(a,b) (((a)>(b)) ? (a):(b))
#define MIN(a,b) (((a)<(b)) ? (a):(b))
#define lowbit(x) ((x) & (-x))
using namespace std;
typedef long long LL;
int const N=1000005;
int n,m,r,cnt,sum;
int a[N],tin[N],tout[N],h[N],id[N];
LL s[N],ss[N];
void read(int &x) {
x=0;
char c=0;
int w=0;
while (!isdigit(c)) w|=c=='-',c=getchar();
while (isdigit(c)) x=x*10+(c^48),c=getchar();
if(w) x = -x;
}
struct edge {
int to, nt;
} e[N << 1];
void add(int a, int b) {
e[++cnt].to = b;
e[cnt].nt = h[a];
h[a] = cnt;
e[++cnt].to = a;
e[cnt].nt = h[b];
h[b] = cnt;
}
void dfs(int x, int fa) {
tin[x] = ++sum;
id[sum] = x;
for (int i = h[x]; i; i = e[i].nt) {
int v = e[i].to;
if (v == fa)
continue;
dfs(v, x);
}
tout[x] = sum;
}
void Add(int x, int v) {
for (int i = x; i <= n; i += lowbit(i)) {
s[i] += v;
ss[i] += (LL)v * (x - 1);
}
}
LL getsum(int x) {
LL ret = 0;
for (int i = x; i; i -= lowbit(i)) {
ret += (LL)x * s[i];
ret -= ss[i];
}
return ret;
}
int main() {
read(n);
read(m);
read(r);
for (int i = 1; i <= n; i++) read(a[i]);
for (int i = 1; i < n; i++) {
int x, y;
scanf("%d%d", &x, &y);
add(x, y);
}
dfs(r, r);
for (int i = 1; i <= n; i++) Add(i, a[id[i]] - a[id[i - 1]]);
while (m--) {
int k, x, y;
scanf("%d", &k);
if (k == 1) {
scanf("%d%d", &x, &y);
int l = tin[x];
int r = tout[x];
Add(l, y);
Add(r + 1, -y);
} else {
scanf("%d", &x);
int l = tin[x];
int r = tout[x];
printf("%lld\n", getsum(r) - getsum(l - 1));
}
}
return 0;
}
例3 DFS 序 3,树上差分 1
题目描述
这是一道模板题。
不保证无快读的程序能过。请务必使用快读。
给一棵有根树,这棵树由编号为
1
…
N
1…N
1…N 的
N
N
N 个结点组成。根结点的编号为
R
R
R。每个结点都有一个权值,结点
i
i
i 的权值为
v
i
v_i
vi。
接下来有 M 组操作,操作分为三类:
1 a b x
,表示将「结点 a a a 到结点 b b b 的简单路径」上所有结点的权值都增加 x x x;2 a
,表示求结点 a a a 的权值。3 a
,表示求 a a a 的子树上所有结点的权值之和。
输入格式
第一行有三个整数
N
,
M
N,M
N,M 和
R
R
R。
第二行有
N
N
N个整数,第
i
i
i 个整数表示
v
i
v_i
vi。
在接下来的
N
−
1
N−1
N−1 行中,每行两个整数,表示一条边。
在接下来的
M
M
M 行中,每行一组操作。
输出格式
对于每组 2 a
操作,输出一个整数,表示结点 a
的权值。
样例输入 1
10 15 3
4 8 -2 -4 -7 -7 -9 5 2 5
3 9
3 4
4 5
4 8
8 7
3 6
8 2
9 10
2 1
2 5
1 4 7 3
1 7 2 6
1 6 7 -7
2 1
1 10 10 -9
2 4
1 2 9 -8
2 6
1 10 5 -2
1 4 4 6
1 6 1 3
1 1 10 2
1 9 2 0
2 7
样例输出 1
-7
4
-8
-14
-7
样例输入 2
10 17 3
5 1 -7 -9 -5 3 -7 -5 3 3
1 8
8 7
7 6
8 3
6 10
7 2
6 9
1 4
6 5
2 9
1 10 4 -2
2 8
1 1 10 -2
3 5
1 10 6 -3
3 1
1 6 5 9
2 8
1 4 5 1
2 10
1 2 5 6
1 2 6 0
1 2 7 -5
1 4 9 6
1 10 1 0
3 2
样例输出 2
3
-7
-5
-10
-9
-4
2
数据范围与提示
40
40
40% 的数据不含操作 3
。
对于所有数据,
1
⩽
N
,
M
⩽
1
0
6
,
1
⩽
R
⩽
N
,
−
1
0
6
⩽
v
i
,
x
⩽
1
0
6
.
1⩽N,M⩽10^6, 1⩽R⩽N, −10^6⩽v_i,x⩽10^6.
1⩽N,M⩽106,1⩽R⩽N,−106⩽vi,x⩽106.
解析
本题可以做树剖,但是树剖的时间复杂度是: O ( n l o g n ∗ l o g n ) O(nlogn*logn) O(nlogn∗logn)
我们可以做树上差分
- 对于操作1: a到b路径上每个数增加x,我们可 以给a和b打一个+x的标记,lca打一个-x的标记,
- lca的父亲打一个-x的标记。这样就可以处理操 作2和操作3了。
- 对于操作2: 查询点的值,就是查询这个子树 的和。
- 对于操作3:我们考虑子树里面每个修改对答案 的贡献,假设我们要查询以u为根的子树,子树 里面有一个点v,我们考虑v对答案的贡献就是 v a l [ v ] ∗ ( d e p [ v ] − d e p [ u ] + 1 ) val[v]*(dep[v]-dep[u]+1) val[v]∗(dep[v]−dep[u]+1),拆开以后就是 v a l [ v ] ∗ d e p [ v ] − v a l [ v ] ∗ ( d e p [ u ] − 1 ) val[v]*dep[v]-val[v]*(dep[u]-1) val[v]∗dep[v]−val[v]∗(dep[u]−1),分别维护两个树状数组即可。
代码
#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define R register int
#define re(i,a,b) for(R i=a; i<=b; i++)
#define ms(i,a) memset(a,i,sizeof(a))
#define MAX(a,b) (((a)>(b)) ? (a):(b))
#define MIN(a,b) (((a)<(b)) ? (a):(b))
#define lowbit(x) ((x) & (-x))
using namespace std;
typedef long long LL;
int const N=1000005;
inline void read(int &x){
x=0;
char c=0;
int w=0;
while (!isdigit(c)) w|=c=='-',c=getchar();
while (isdigit(c)) x=x*10+(c^48),c=getchar();
if(w) x=-x;
}
struct edge{
int to,nt;
} e[N<<1];
int n,m,rt,cnt,sum;
int dep[N],tin[N],tout[N],h[N],a[N];
int f[N][20];
LL s[N],ss[N],t[N];
void add(int a,int b){
e[++cnt].to=b,e[cnt].nt=h[a],h[a]=cnt;
e[++cnt].to=a,e[cnt].nt=h[b],h[b]=cnt;
}
void dfs(int x,int fa,int d){
dep[x]=d;
tin[x]=++sum;
f[x][0]=fa;
for(int i=h[x]; i; i=e[i].nt){
int v=e[i].to;
if(v==fa) continue;
dfs(v,x,d+1);
}
tout[x]=sum;
}
void Add(int x,int v){
for(int i=x;i<=n;i+=lowbit(i))
t[i]+=v;
}
int ancestor(int x,int y){
return tin[x]<=tin[y] && tout[y]<=tout[x];
}
int lca(int x,int y){
if(ancestor(x,y)) return x;
if(ancestor(y,x)) return y;
for(int i=19; i>=0; i--)
if(!ancestor(f[x][i],y))
x=f[x][i];
return f[x][0];
}
LL getsum(int x,LL s[]){
LL ret=0;
for(int i=x; i; i-=lowbit(i))
ret+=s[i];
return ret;
}
inline void Add2(int x,LL d,int v){
for(int i=x;i<=n;i+=lowbit(i)) {
s[i]+=v;
ss[i]+=d*v;
}
}
int main(){
read(n);
read(m);
read(rt);
for(int i=1; i<=n; i++) read(a[i]);
for(int i=1; i<n; i++) {
int x,y;
read(x);
read(y);
add(x,y);
}
dfs(rt,rt,1);
for(int j=1; j<20; j++) for(int i=1; i<=n; i++) f[i][j]=f[f[i][j-1]][j-1];
for(int i=1;i<=n;i++) Add(tin[i],a[i]);
while(m--) {
int k,l,r,x;
read(k);
if(k==1){
read(l);
read(r);
read(x);
Add2(tin[l],dep[l],x);
Add2(tin[r],dep[r],x);
int t=lca(l,r);
Add2(tin[t],dep[t],-x);
if(t!=rt) Add2(tin[f[t][0]],dep[t]-1,-x);
} else if(k==2) {
read(l);
printf("%lld\n",getsum(tout[l],s)-getsum(tin[l]-1,s)+a[l]);
} else {
read(l);
LL t1=getsum(tout[l],t)-getsum(tin[l]-1,t);
LL t2=getsum(tout[l],ss)-getsum(tin[l]-1,ss);
LL t3=getsum(tout[l],s)-getsum(tin[l]-1,s);
printf("%lld\n",t2-t3*(dep[l]-1)+t1);
}
}
return 0;
}
例4 DFS序4
题目描述
这是一道模板题。
本题严重卡常,请务必使用 fread
快读,不保证无快读的程序能过(虽然标程没用快读)。另外,建议使用 Tarjan 或树剖求 LCA。
给一棵有根树,这棵树由编号为
1
…
N
1…N
1…N 的
N
N
N 个结点组成。根结点的编号为
R
R
R。每个结点都有一个权值,结点
i
i
i 的权值为
v
i
v_i
vi。
接下来有
M
M
M 组操作,操作分为三类:
1 a x
,表示将结点 a a a 的权值增加 x x x;2 a x
,表示将 a a a 的子树上所有结点的权值增加 x x x;3 a b
,表示求「结点 a a a 到结点 b b b 的简单路径」上所有结点的权值之和。
输入格式
第一行有三个整数
N
,
M
N,M
N,M 和
R
R
R。
第二行有
N
N
N 个整数,第
i
i
i 个整数表示
v
i
v_i
vi。
在接下来的
N
−
1
N−1
N−1 行中,每行两个整数,表示一条边。
在接下来的
M
M
M 行中,每行一组操作。
输出格式
对于每组 3 a b
操作,输出一个整数,表示「结点
a
a
a 到结点
b
b
b 的简单路径」上所有结点的权值之和(含结点
a
,
b
a, b
a,b)。
样例输入 1
10 13 5
-2 -7 0 2 -9 -2 -4 9 8 -1
9 8
9 4
9 2
4 10
10 7
10 6
2 1
8 3
7 5
3 8 6
1 7 -8
1 5 -9
1 5 -4
1 4 -2
1 2 -1
3 5 1
1 7 1
3 1 3
1 1 -3
3 10 2
1 1 -8
3 8 4
样例输出 1
16
-37
7
-1
17
样例输入 2
10 16 4
-13 -11 5 4 18 13 14 -8 -8 14
4 1
4 10
10 2
2 8
4 7
1 6
8 5
1 3
2 9
3 5 10
1 5 -5
2 9 -4
3 8 6
1 5 -8
2 8 -5
3 8 7
1 9 0
2 10 -3
3 7 6
2 9 -4
2 8 2
3 4 4
2 1 8
1 6 5
3 8 3
样例输出 2
13
-1
8
18
4
-5
数据范围与提示
40
40
40% 的数据不含操作 2
。
1
⩽
N
,
M
⩽
1
0
6
,
1
⩽
R
⩽
N
,
−
1
0
6
⩽
v
i
,
x
⩽
1
0
6
.
1⩽N,M⩽10^6, 1⩽R⩽N, −10^6⩽v_i,x⩽10^6.
1⩽N,M⩽106,1⩽R⩽N,−106⩽vi,x⩽106.
代码
#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define R register int
#define re(i,a,b) for(R i=a; i<=b; i++)
#define ms(i,a) memset(a,i,sizeof(a))
#define MAX(a,b) (((a)>(b)) ? (a):(b))
#define MIN(a,b) (((a)<(b)) ? (a):(b))
#define lowbit(x) ((x) & (-x))
using namespace std;
typedef long long LL;
namespace IN {
#include <cctype>
#include <cstdio>
#define bsiz 1000000
int sta[30];
char buf[bsiz], pbuf[bsiz], *p = pbuf, *s = buf, *t = buf;
#define mgetc() (s == t && (t = (s = buf) + fread(buf, 1, bsiz, stdin), s == t) ? EOF : *s++)
inline int read() {
register char ch;
register int res=0, p;
while (!isdigit(ch = mgetc()) && (ch ^ '-'));
p = ch == '-' ? ch = mgetc(), -1 : 1;
while (isdigit(ch)) res = (res << 3) + (res << 1) + (ch ^ 48), ch = mgetc();
return res*p;
}
}
const int N=1e6+5;
struct edge{
int to,nt;
} e[N<<1];
int cnt,sum,a[N],h[N],n,m,rt,tin[N],tout[N],f[N][20],dep[N];
LL s[N],ss[N],val[N],d[N];
void add(int a,int b){
e[++cnt].to=b; e[cnt].nt=h[a]; h[a]=cnt;
e[++cnt].to=a; e[cnt].nt=h[b]; h[b]=cnt;
}
void dfs(int x,int fa,int d,LL tot) {
tin[x]=++sum;
f[x][0]=fa;
val[x]=tot;
dep[x]=d;
for(int i=h[x]; i; i=e[i].nt) {
int v=e[i].to;
if(v==fa) continue;
dfs(v,x,d+1,tot+a[v]);
}
tout[x]=sum;
}
int inline ancestor(int x,int y) {
return tin[x]<=tin[y] && tout[y]<=tout[x];
}
int lca(int x,int y) {
if(ancestor(x,y)) return x;
if(ancestor(y,x)) return y;
for(int i=19; i>=0; i--)
if(!ancestor(f[x][i],y))
x=f[x][i];
return f[x][0];
}
void Add(int x,LL v,LL s[]){
for(int i=x; i<=n; i+=lowbit(i))
s[i]+=v;
}
LL getsum(int x,LL s[]){
LL ret=0;
for(int i=x; i; i-=lowbit(i))
ret+=s[i];
return ret;
}
int main() {
n=IN::read();
m=IN::read();
rt=IN::read();
for(int i=1; i<=n; i++)
a[i]=IN::read();
for(int i=1; i<n; i++) {
int x,y;
x=IN::read();
y=IN::read();
add(x,y);
}
dfs(rt,rt,1,a[rt]);
for(int j=1; j<20; j++)
for(int i=1; i<=n; i++)
f[i][j]=f[f[i][j-1]][j-1];
while(m--) {
int k,a,b,x;
k=IN::read();
if(k==1) {
a=IN::read();
x=IN::read();
Add(tin[a],x,d);
Add(tout[a]+1,-x,d);
} else if(k==2) {
a=IN::read();
x=IN::read();
Add(tin[a],(LL)x*dep[a],s);
Add(tout[a]+1,-(LL)x*dep[a],s);
Add(tin[a],x,ss);
Add(tout[a]+1,-x,ss);
} else {
a=IN::read();
b=IN::read();
int t=lca(a,b);
LL t1=val[a]+val[b]-val[t];
if(t!=rt) t1-=val[f[t][0]];
LL t2=getsum(tin[a],d)+getsum(tin[b],d)-getsum(tin[t],d);
if(t!=rt) t2-=getsum(tin[f[t][0]],d);
LL t3=getsum(tin[a],s)+getsum(tin[b],s)-getsum(tin[t],s);
if(t!=rt) t3-=getsum(tin[f[t][0]],s );
LL t4=getsum(tin[a],ss)*(dep[a]+1)+getsum(tin[b],ss)*(dep[b]+1)-getsum(tin[t],ss)*(dep[t]+1);
if(t!=rt) t4-=getsum(tin[f[t][0]],ss)*(dep[f[t][0]]+1);
printf("%lld\n",t1+t2-t3+t4);
}
}
return 0;
}